Commit cbb6491d by Ian Lance Taylor

libgo: Update to weekly.2012-02-14 release.

From-SVN: r184798
parent ff2f581b
...@@ -88,60 +88,6 @@ Gogo::import_unsafe(const std::string& local_name, bool is_local_name_exported, ...@@ -88,60 +88,6 @@ Gogo::import_unsafe(const std::string& local_name, bool is_local_name_exported,
if (add_to_globals) if (add_to_globals)
this->add_named_object(no); this->add_named_object(no);
// Typeof.
Type* empty_interface = Type::make_empty_interface_type(bloc);
Typed_identifier_list* parameters = new Typed_identifier_list;
parameters->push_back(Typed_identifier("i", empty_interface, bloc));
results = new Typed_identifier_list;
results->push_back(Typed_identifier("", empty_interface, bloc));
fntype = Type::make_function_type(NULL, parameters, results, bloc);
no = bindings->add_function_declaration("Typeof", package, fntype, bloc);
if (add_to_globals)
this->add_named_object(no);
// Reflect.
parameters = new Typed_identifier_list;
parameters->push_back(Typed_identifier("it", empty_interface, bloc));
results = new Typed_identifier_list;
results->push_back(Typed_identifier("", empty_interface, bloc));
results->push_back(Typed_identifier("", pointer_type, bloc));
fntype = Type::make_function_type(NULL, parameters, results, bloc);
no = bindings->add_function_declaration("Reflect", package, fntype, bloc);
if (add_to_globals)
this->add_named_object(no);
// Unreflect.
parameters = new Typed_identifier_list;
parameters->push_back(Typed_identifier("typ", empty_interface, bloc));
parameters->push_back(Typed_identifier("addr", pointer_type, bloc));
results = new Typed_identifier_list;
results->push_back(Typed_identifier("", empty_interface, bloc));
fntype = Type::make_function_type(NULL, parameters, results, bloc);
no = bindings->add_function_declaration("Unreflect", package, fntype, bloc);
if (add_to_globals)
this->add_named_object(no);
// New.
parameters = new Typed_identifier_list;
parameters->push_back(Typed_identifier("typ", empty_interface, bloc));
results = new Typed_identifier_list;
results->push_back(Typed_identifier("", pointer_type, bloc));
fntype = Type::make_function_type(NULL, parameters, results, bloc);
no = bindings->add_function_declaration("New", package, fntype, bloc);
if (add_to_globals)
this->add_named_object(no);
// NewArray.
parameters = new Typed_identifier_list;
parameters->push_back(Typed_identifier("typ", empty_interface, bloc));
parameters->push_back(Typed_identifier("n", int_type, bloc));
results = new Typed_identifier_list;
results->push_back(Typed_identifier("", pointer_type, bloc));
fntype = Type::make_function_type(NULL, parameters, results, bloc);
no = bindings->add_function_declaration("NewArray", package, fntype, bloc);
if (add_to_globals)
this->add_named_object(no);
if (!this->imported_unsafe_) if (!this->imported_unsafe_)
{ {
go_imported_unsafe(); go_imported_unsafe();
......
...@@ -9,13 +9,13 @@ ...@@ -9,13 +9,13 @@
package main package main
import "runtime" import "runtime"
func foo(runtime.UintType, i int) { // ERROR "cannot declare name runtime.UintType|named/anonymous mix" func foo(runtime.UintType, i int) { // ERROR "cannot declare name runtime.UintType|named/anonymous mix|undefined identifier"
println(i, runtime.UintType) println(i, runtime.UintType) // GCCGO_ERROR "undefined identifier"
} }
func bar(i int) { func bar(i int) {
runtime.UintType := i // ERROR "cannot declare name runtime.UintType|non-name on left side" runtime.UintType := i // ERROR "cannot declare name runtime.UintType|non-name on left side|undefined identifier"
println(runtime.UintType) // GCCGO_ERROR "invalid use of type" println(runtime.UintType) // GCCGO_ERROR "invalid use of type|undefined identifier"
} }
func baz() { func baz() {
......
52ba9506bd99 43cf9b39b647
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.
...@@ -225,7 +225,6 @@ toolexeclibgoexp_DATA = \ ...@@ -225,7 +225,6 @@ toolexeclibgoexp_DATA = \
$(exp_inotify_gox) \ $(exp_inotify_gox) \
exp/norm.gox \ exp/norm.gox \
exp/proxy.gox \ exp/proxy.gox \
exp/signal.gox \
exp/terminal.gox \ exp/terminal.gox \
exp/types.gox \ exp/types.gox \
exp/utf8string.gox exp/utf8string.gox
...@@ -325,6 +324,7 @@ toolexeclibgoosdir = $(toolexeclibgodir)/os ...@@ -325,6 +324,7 @@ toolexeclibgoosdir = $(toolexeclibgodir)/os
toolexeclibgoos_DATA = \ toolexeclibgoos_DATA = \
os/exec.gox \ os/exec.gox \
os/signal.gox \
os/user.gox os/user.gox
toolexeclibgopathdir = $(toolexeclibgodir)/path toolexeclibgopathdir = $(toolexeclibgodir)/path
...@@ -352,8 +352,7 @@ toolexeclibgotestingdir = $(toolexeclibgodir)/testing ...@@ -352,8 +352,7 @@ toolexeclibgotestingdir = $(toolexeclibgodir)/testing
toolexeclibgotesting_DATA = \ toolexeclibgotesting_DATA = \
testing/iotest.gox \ testing/iotest.gox \
testing/quick.gox \ testing/quick.gox
testing/script.gox
toolexeclibgotextdir = $(toolexeclibgodir)/text toolexeclibgotextdir = $(toolexeclibgodir)/text
...@@ -398,6 +397,7 @@ runtime_files = \ ...@@ -398,6 +397,7 @@ runtime_files = \
runtime/go-byte-array-to-string.c \ runtime/go-byte-array-to-string.c \
runtime/go-breakpoint.c \ runtime/go-breakpoint.c \
runtime/go-caller.c \ runtime/go-caller.c \
runtime/go-callers.c \
runtime/go-can-convert-interface.c \ runtime/go-can-convert-interface.c \
runtime/go-cgo.c \ runtime/go-cgo.c \
runtime/go-check-interface.c \ runtime/go-check-interface.c \
...@@ -428,7 +428,6 @@ runtime_files = \ ...@@ -428,7 +428,6 @@ runtime_files = \
runtime/go-panic.c \ runtime/go-panic.c \
runtime/go-print.c \ runtime/go-print.c \
runtime/go-recover.c \ runtime/go-recover.c \
runtime/go-reflect.c \
runtime/go-reflect-call.c \ runtime/go-reflect-call.c \
runtime/go-reflect-map.c \ runtime/go-reflect-map.c \
runtime/go-rune.c \ runtime/go-rune.c \
...@@ -450,7 +449,6 @@ runtime_files = \ ...@@ -450,7 +449,6 @@ runtime_files = \
runtime/go-type-string.c \ runtime/go-type-string.c \
runtime/go-typedesc-equal.c \ runtime/go-typedesc-equal.c \
runtime/go-typestring.c \ runtime/go-typestring.c \
runtime/go-unreflect.c \
runtime/go-unsafe-new.c \ runtime/go-unsafe-new.c \
runtime/go-unsafe-newarray.c \ runtime/go-unsafe-newarray.c \
runtime/go-unsafe-pointer.c \ runtime/go-unsafe-pointer.c \
...@@ -468,6 +466,7 @@ runtime_files = \ ...@@ -468,6 +466,7 @@ runtime_files = \
runtime/msize.c \ runtime/msize.c \
runtime/proc.c \ runtime/proc.c \
runtime/runtime.c \ runtime/runtime.c \
runtime/signal_unix.c \
runtime/thread.c \ runtime/thread.c \
runtime/yield.c \ runtime/yield.c \
$(rtems_task_variable_add_file) \ $(rtems_task_variable_add_file) \
...@@ -509,7 +508,7 @@ sema.c: $(srcdir)/runtime/sema.goc goc2c ...@@ -509,7 +508,7 @@ sema.c: $(srcdir)/runtime/sema.goc goc2c
mv -f $@.tmp $@ mv -f $@.tmp $@
sigqueue.c: $(srcdir)/runtime/sigqueue.goc goc2c sigqueue.c: $(srcdir)/runtime/sigqueue.goc goc2c
./goc2c --gcc --go-prefix libgo_runtime $< > $@.tmp ./goc2c --gcc --go-prefix libgo_os $< > $@.tmp
mv -f $@.tmp $@ mv -f $@.tmp $@
time.c: $(srcdir)/runtime/time.goc goc2c time.c: $(srcdir)/runtime/time.goc goc2c
...@@ -526,7 +525,8 @@ go_bufio_files = \ ...@@ -526,7 +525,8 @@ go_bufio_files = \
go_bytes_files = \ go_bytes_files = \
go/bytes/buffer.go \ go/bytes/buffer.go \
go/bytes/bytes.go \ go/bytes/bytes.go \
go/bytes/bytes_decl.go go/bytes/bytes_decl.go \
go/bytes/reader.go
go_bytes_c_files = \ go_bytes_c_files = \
go/bytes/indexbyte.c go/bytes/indexbyte.c
...@@ -784,9 +784,7 @@ go_os_files = \ ...@@ -784,9 +784,7 @@ go_os_files = \
$(go_os_stat_file) \ $(go_os_stat_file) \
go/os/str.go \ go/os/str.go \
$(go_os_sys_file) \ $(go_os_sys_file) \
go/os/time.go \ go/os/types.go
go/os/types.go \
signal_unix.go
go_path_files = \ go_path_files = \
go/path/match.go \ go/path/match.go \
...@@ -811,7 +809,6 @@ go_runtime_files = \ ...@@ -811,7 +809,6 @@ go_runtime_files = \
go/runtime/error.go \ go/runtime/error.go \
go/runtime/extern.go \ go/runtime/extern.go \
go/runtime/mem.go \ go/runtime/mem.go \
go/runtime/sig.go \
go/runtime/softfloat64.go \ go/runtime/softfloat64.go \
go/runtime/type.go \ go/runtime/type.go \
version.go version.go
...@@ -1103,8 +1100,6 @@ go_exp_proxy_files = \ ...@@ -1103,8 +1100,6 @@ go_exp_proxy_files = \
go/exp/proxy/per_host.go \ go/exp/proxy/per_host.go \
go/exp/proxy/proxy.go \ go/exp/proxy/proxy.go \
go/exp/proxy/socks5.go go/exp/proxy/socks5.go
go_exp_signal_files = \
go/exp/signal/signal.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/terminal.go \ go/exp/terminal/terminal.go \
go/exp/terminal/util.go go/exp/terminal/util.go
...@@ -1302,6 +1297,10 @@ go_os_exec_files = \ ...@@ -1302,6 +1297,10 @@ go_os_exec_files = \
go/os/exec/exec.go \ go/os/exec/exec.go \
go/os/exec/lp_unix.go go/os/exec/lp_unix.go
go_os_signal_files = \
go/os/signal/signal.go \
go/os/signal/signal_unix.go
go_os_user_files = \ go_os_user_files = \
go/os/user/user.go \ go/os/user/user.go \
go/os/user/lookup_unix.go go/os/user/lookup_unix.go
...@@ -1352,8 +1351,6 @@ go_testing_iotest_files = \ ...@@ -1352,8 +1351,6 @@ go_testing_iotest_files = \
go/testing/iotest/writer.go go/testing/iotest/writer.go
go_testing_quick_files = \ go_testing_quick_files = \
go/testing/quick/quick.go go/testing/quick/quick.go
go_testing_script_files = \
go/testing/script/script.go
go_text_scanner_files = \ go_text_scanner_files = \
go/text/scanner/scanner.go go/text/scanner/scanner.go
...@@ -1529,6 +1526,7 @@ go_syscall_files = \ ...@@ -1529,6 +1526,7 @@ go_syscall_files = \
syscall_arch.go syscall_arch.go
go_syscall_c_files = \ go_syscall_c_files = \
go/syscall/errno.c \ go/syscall/errno.c \
go/syscall/signame.c \
$(syscall_wait_c_file) $(syscall_wait_c_file)
libcalls.go: s-libcalls; @true libcalls.go: s-libcalls; @true
...@@ -1667,7 +1665,6 @@ libgo_go_objs = \ ...@@ -1667,7 +1665,6 @@ libgo_go_objs = \
exp/html.lo \ exp/html.lo \
exp/norm.lo \ exp/norm.lo \
exp/proxy.lo \ exp/proxy.lo \
exp/signal.lo \
exp/terminal.lo \ exp/terminal.lo \
exp/types.lo \ exp/types.lo \
exp/utf8string.lo \ exp/utf8string.lo \
...@@ -1712,6 +1709,7 @@ libgo_go_objs = \ ...@@ -1712,6 +1709,7 @@ libgo_go_objs = \
old/regexp.lo \ old/regexp.lo \
old/template.lo \ old/template.lo \
$(os_lib_inotify_lo) \ $(os_lib_inotify_lo) \
os/signal.lo \
os/user.lo \ os/user.lo \
path/filepath.lo \ path/filepath.lo \
regexp/syntax.lo \ regexp/syntax.lo \
...@@ -1722,6 +1720,7 @@ libgo_go_objs = \ ...@@ -1722,6 +1720,7 @@ libgo_go_objs = \
sync/atomic_c.lo \ sync/atomic_c.lo \
syscall/syscall.lo \ syscall/syscall.lo \
syscall/errno.lo \ syscall/errno.lo \
syscall/signame.lo \
syscall/wait.lo \ syscall/wait.lo \
text/scanner.lo \ text/scanner.lo \
text/tabwriter.lo \ text/tabwriter.lo \
...@@ -1730,7 +1729,6 @@ libgo_go_objs = \ ...@@ -1730,7 +1729,6 @@ libgo_go_objs = \
testing/testing.lo \ testing/testing.lo \
testing/iotest.lo \ testing/iotest.lo \
testing/quick.lo \ testing/quick.lo \
testing/script.lo \
unicode/utf16.lo \ unicode/utf16.lo \
unicode/utf8.lo unicode/utf8.lo
...@@ -1986,10 +1984,6 @@ os/check: $(CHECK_DEPS) ...@@ -1986,10 +1984,6 @@ os/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: os/check .PHONY: os/check
signal_unix.go: $(srcdir)/go/os/mkunixsignals.sh sysinfo.go
$(SHELL) $(srcdir)/go/os/mkunixsignals.sh sysinfo.go > $@.tmp
mv -f $@.tmp $@
@go_include@ path/path.lo.dep @go_include@ path/path.lo.dep
path/path.lo.dep: $(go_path_files) path/path.lo.dep: $(go_path_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -2599,16 +2593,6 @@ exp/proxy/check: $(CHECK_DEPS) ...@@ -2599,16 +2593,6 @@ exp/proxy/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: exp/proxy/check .PHONY: exp/proxy/check
@go_include@ exp/signal.lo.dep
exp/signal.lo.dep: $(go_exp_signal_files)
$(BUILDDEPS)
exp/signal.lo: $(go_exp_signal_files)
$(BUILDPACKAGE)
exp/signal/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/signal
@$(CHECK)
.PHONY: exp/signal/check
@go_include@ exp/terminal.lo.dep @go_include@ exp/terminal.lo.dep
exp/terminal.lo.dep: $(go_exp_terminal_files) exp/terminal.lo.dep: $(go_exp_terminal_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -3060,6 +3044,16 @@ os/exec/check: $(CHECK_DEPS) ...@@ -3060,6 +3044,16 @@ os/exec/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: os/exec/check .PHONY: os/exec/check
@go_include@ os/signal.lo.dep
os/signal.lo.dep: $(go_os_signal_files)
$(BUILDDEPS)
os/signal.lo: $(go_os_signal_files)
$(BUILDPACKAGE)
os/signal/check: $(CHECK_DEPS)
@$(MKDIR_P) os/signal
@$(CHECK)
.PHONY: os/signal/check
@go_include@ os/user.lo.dep @go_include@ os/user.lo.dep
os/user.lo.dep: $(go_os_user_files) os/user.lo.dep: $(go_os_user_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -3171,16 +3165,6 @@ testing/quick/check: $(CHECK_DEPS) ...@@ -3171,16 +3165,6 @@ testing/quick/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: testing/quick/check .PHONY: testing/quick/check
@go_include@ testing/script.lo.dep
testing/script.lo.dep: $(go_testing_script_files)
$(BUILDDEPS)
testing/script.lo: $(go_testing_script_files)
$(BUILDPACKAGE)
testing/script/check: $(CHECK_DEPS)
@$(MKDIR_P) testing/script
@$(CHECK)
.PHONY: testing/script/check
@go_include@ unicode/utf16.lo.dep @go_include@ unicode/utf16.lo.dep
unicode/utf16.lo.dep: $(go_unicode_utf16_files) unicode/utf16.lo.dep: $(go_unicode_utf16_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -3208,6 +3192,8 @@ syscall/syscall.lo: $(go_syscall_files) ...@@ -3208,6 +3192,8 @@ syscall/syscall.lo: $(go_syscall_files)
$(BUILDPACKAGE) $(BUILDPACKAGE)
syscall/errno.lo: go/syscall/errno.c syscall/errno.lo: go/syscall/errno.c
$(LTCOMPILE) -c -o $@ $< $(LTCOMPILE) -c -o $@ $<
syscall/signame.lo: go/syscall/signame.c
$(LTCOMPILE) -c -o $@ $<
syscall/wait.lo: go/syscall/wait.c syscall/wait.lo: go/syscall/wait.c
$(LTCOMPILE) -c -o $@ $< $(LTCOMPILE) -c -o $@ $<
...@@ -3384,8 +3370,6 @@ exp/norm.gox: exp/norm.lo ...@@ -3384,8 +3370,6 @@ exp/norm.gox: exp/norm.lo
$(BUILDGOX) $(BUILDGOX)
exp/proxy.gox: exp/proxy.lo exp/proxy.gox: exp/proxy.lo
$(BUILDGOX) $(BUILDGOX)
exp/signal.gox: exp/signal.lo
$(BUILDGOX)
exp/terminal.gox: exp/terminal.lo exp/terminal.gox: exp/terminal.lo
$(BUILDGOX) $(BUILDGOX)
exp/types.gox: exp/types.lo exp/types.gox: exp/types.lo
...@@ -3486,6 +3470,8 @@ old/template.gox: old/template.lo ...@@ -3486,6 +3470,8 @@ old/template.gox: old/template.lo
os/exec.gox: os/exec.lo os/exec.gox: os/exec.lo
$(BUILDGOX) $(BUILDGOX)
os/signal.gox: os/signal.lo
$(BUILDGOX)
os/user.gox: os/user.lo os/user.gox: os/user.lo
$(BUILDGOX) $(BUILDGOX)
...@@ -3516,8 +3502,6 @@ testing/iotest.gox: testing/iotest.lo ...@@ -3516,8 +3502,6 @@ testing/iotest.gox: testing/iotest.lo
$(BUILDGOX) $(BUILDGOX)
testing/quick.gox: testing/quick.lo testing/quick.gox: testing/quick.lo
$(BUILDGOX) $(BUILDGOX)
testing/script.gox: testing/script.lo
$(BUILDGOX)
unicode/utf16.gox: unicode/utf16.lo unicode/utf16.gox: unicode/utf16.lo
$(BUILDGOX) $(BUILDGOX)
...@@ -3605,7 +3589,6 @@ TEST_PACKAGES = \ ...@@ -3605,7 +3589,6 @@ TEST_PACKAGES = \
$(exp_inotify_check) \ $(exp_inotify_check) \
exp/norm/check \ exp/norm/check \
exp/proxy/check \ exp/proxy/check \
exp/signal/check \
exp/terminal/check \ exp/terminal/check \
exp/utf8string/check \ exp/utf8string/check \
html/template/check \ html/template/check \
...@@ -3635,6 +3618,7 @@ TEST_PACKAGES = \ ...@@ -3635,6 +3618,7 @@ TEST_PACKAGES = \
net/http/check \ net/http/check \
net/http/cgi/check \ net/http/cgi/check \
net/http/fcgi/check \ net/http/fcgi/check \
net/http/httptest/check \
net/http/httputil/check \ net/http/httputil/check \
net/mail/check \ net/mail/check \
net/rpc/check \ net/rpc/check \
...@@ -3646,6 +3630,7 @@ TEST_PACKAGES = \ ...@@ -3646,6 +3630,7 @@ TEST_PACKAGES = \
old/regexp/check \ old/regexp/check \
old/template/check \ old/template/check \
os/exec/check \ os/exec/check \
os/signal/check \
os/user/check \ os/user/check \
path/filepath/check \ path/filepath/check \
regexp/syntax/check \ regexp/syntax/check \
...@@ -3655,7 +3640,6 @@ TEST_PACKAGES = \ ...@@ -3655,7 +3640,6 @@ TEST_PACKAGES = \
text/template/check \ text/template/check \
text/template/parse/check \ text/template/parse/check \
testing/quick/check \ testing/quick/check \
testing/script/check \
unicode/utf16/check \ unicode/utf16/check \
unicode/utf8/check unicode/utf8/check
......
...@@ -278,7 +278,7 @@ func TestInvalidFiles(t *testing.T) { ...@@ -278,7 +278,7 @@ func TestInvalidFiles(t *testing.T) {
b := make([]byte, size) b := make([]byte, size)
// zeroes // zeroes
_, err := NewReader(sliceReaderAt(b), size) _, err := NewReader(bytes.NewReader(b), size)
if err != ErrFormat { if err != ErrFormat {
t.Errorf("zeroes: error=%v, want %v", err, ErrFormat) t.Errorf("zeroes: error=%v, want %v", err, ErrFormat)
} }
...@@ -289,15 +289,8 @@ func TestInvalidFiles(t *testing.T) { ...@@ -289,15 +289,8 @@ func TestInvalidFiles(t *testing.T) {
for i := 0; i < size-4; i += 4 { for i := 0; i < size-4; i += 4 {
copy(b[i:i+4], sig) copy(b[i:i+4], sig)
} }
_, err = NewReader(sliceReaderAt(b), size) _, err = NewReader(bytes.NewReader(b), size)
if err != ErrFormat { if err != ErrFormat {
t.Errorf("sigs: error=%v, want %v", err, ErrFormat) t.Errorf("sigs: error=%v, want %v", err, ErrFormat)
} }
} }
type sliceReaderAt []byte
func (r sliceReaderAt) ReadAt(b []byte, off int64) (int, error) {
copy(b, r[int(off):int(off)+len(b)])
return len(b), nil
}
...@@ -19,7 +19,7 @@ import ( ...@@ -19,7 +19,7 @@ import (
// Writer implements a zip file writer. // Writer implements a zip file writer.
type Writer struct { type Writer struct {
countWriter cw *countWriter
dir []*header dir []*header
last *fileWriter last *fileWriter
closed bool closed bool
...@@ -32,7 +32,7 @@ type header struct { ...@@ -32,7 +32,7 @@ type header struct {
// NewWriter returns a new Writer writing a zip file to w. // NewWriter returns a new Writer writing a zip file to w.
func NewWriter(w io.Writer) *Writer { func NewWriter(w io.Writer) *Writer {
return &Writer{countWriter: countWriter{w: bufio.NewWriter(w)}} return &Writer{cw: &countWriter{w: bufio.NewWriter(w)}}
} }
// Close finishes writing the zip file by writing the central directory. // Close finishes writing the zip file by writing the central directory.
...@@ -52,42 +52,42 @@ func (w *Writer) Close() (err error) { ...@@ -52,42 +52,42 @@ func (w *Writer) Close() (err error) {
defer recoverError(&err) defer recoverError(&err)
// write central directory // write central directory
start := w.count start := w.cw.count
for _, h := range w.dir { for _, h := range w.dir {
write(w, uint32(directoryHeaderSignature)) write(w.cw, uint32(directoryHeaderSignature))
write(w, h.CreatorVersion) write(w.cw, h.CreatorVersion)
write(w, h.ReaderVersion) write(w.cw, h.ReaderVersion)
write(w, h.Flags) write(w.cw, h.Flags)
write(w, h.Method) write(w.cw, h.Method)
write(w, h.ModifiedTime) write(w.cw, h.ModifiedTime)
write(w, h.ModifiedDate) write(w.cw, h.ModifiedDate)
write(w, h.CRC32) write(w.cw, h.CRC32)
write(w, h.CompressedSize) write(w.cw, h.CompressedSize)
write(w, h.UncompressedSize) write(w.cw, h.UncompressedSize)
write(w, uint16(len(h.Name))) write(w.cw, uint16(len(h.Name)))
write(w, uint16(len(h.Extra))) write(w.cw, uint16(len(h.Extra)))
write(w, uint16(len(h.Comment))) write(w.cw, uint16(len(h.Comment)))
write(w, uint16(0)) // disk number start write(w.cw, uint16(0)) // disk number start
write(w, uint16(0)) // internal file attributes write(w.cw, uint16(0)) // internal file attributes
write(w, h.ExternalAttrs) write(w.cw, h.ExternalAttrs)
write(w, h.offset) write(w.cw, h.offset)
writeBytes(w, []byte(h.Name)) writeBytes(w.cw, []byte(h.Name))
writeBytes(w, h.Extra) writeBytes(w.cw, h.Extra)
writeBytes(w, []byte(h.Comment)) writeBytes(w.cw, []byte(h.Comment))
} }
end := w.count end := w.cw.count
// write end record // write end record
write(w, uint32(directoryEndSignature)) write(w.cw, uint32(directoryEndSignature))
write(w, uint16(0)) // disk number write(w.cw, uint16(0)) // disk number
write(w, uint16(0)) // disk number where directory starts write(w.cw, uint16(0)) // disk number where directory starts
write(w, uint16(len(w.dir))) // number of entries this disk write(w.cw, uint16(len(w.dir))) // number of entries this disk
write(w, uint16(len(w.dir))) // number of entries total write(w.cw, uint16(len(w.dir))) // number of entries total
write(w, uint32(end-start)) // size of directory write(w.cw, uint32(end-start)) // size of directory
write(w, uint32(start)) // start of directory write(w.cw, uint32(start)) // start of directory
write(w, uint16(0)) // size of comment write(w.cw, uint16(0)) // size of comment
return w.w.(*bufio.Writer).Flush() return w.cw.w.(*bufio.Writer).Flush()
} }
// Create adds a file to the zip file using the provided name. // Create adds a file to the zip file using the provided name.
...@@ -119,15 +119,19 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) { ...@@ -119,15 +119,19 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
fh.ReaderVersion = 0x14 fh.ReaderVersion = 0x14
fw := &fileWriter{ fw := &fileWriter{
zipw: w, zipw: w.cw,
compCount: &countWriter{w: w}, compCount: &countWriter{w: w.cw},
crc32: crc32.NewIEEE(), crc32: crc32.NewIEEE(),
} }
switch fh.Method { switch fh.Method {
case Store: case Store:
fw.comp = nopCloser{fw.compCount} fw.comp = nopCloser{fw.compCount}
case Deflate: case Deflate:
fw.comp = flate.NewWriter(fw.compCount, 5) var err error
fw.comp, err = flate.NewWriter(fw.compCount, 5)
if err != nil {
return nil, err
}
default: default:
return nil, ErrAlgorithm return nil, ErrAlgorithm
} }
...@@ -135,12 +139,12 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) { ...@@ -135,12 +139,12 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
h := &header{ h := &header{
FileHeader: fh, FileHeader: fh,
offset: uint32(w.count), offset: uint32(w.cw.count),
} }
w.dir = append(w.dir, h) w.dir = append(w.dir, h)
fw.header = h fw.header = h
if err := writeHeader(w, fh); err != nil { if err := writeHeader(w.cw, fh); err != nil {
return nil, err return nil, err
} }
......
...@@ -77,7 +77,7 @@ func TestWriter(t *testing.T) { ...@@ -77,7 +77,7 @@ func TestWriter(t *testing.T) {
} }
// read it back // read it back
r, err := NewReader(sliceReaderAt(buf.Bytes()), int64(buf.Len())) r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -9,22 +9,12 @@ package zip ...@@ -9,22 +9,12 @@ package zip
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
) )
type stringReaderAt string
func (s stringReaderAt) ReadAt(p []byte, off int64) (n int, err error) {
if off >= int64(len(s)) {
return 0, io.EOF
}
n = copy(p, s[off:])
return
}
func TestOver65kFiles(t *testing.T) { func TestOver65kFiles(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Logf("slow test; skipping") t.Logf("slow test; skipping")
...@@ -42,8 +32,8 @@ func TestOver65kFiles(t *testing.T) { ...@@ -42,8 +32,8 @@ func TestOver65kFiles(t *testing.T) {
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
t.Fatalf("Writer.Close: %v", err) t.Fatalf("Writer.Close: %v", err)
} }
rat := stringReaderAt(buf.String()) s := buf.String()
zr, err := NewReader(rat, int64(len(rat))) zr, err := NewReader(strings.NewReader(s), int64(len(s)))
if err != nil { if err != nil {
t.Fatalf("NewReader: %v", err) t.Fatalf("NewReader: %v", err)
} }
......
...@@ -182,14 +182,21 @@ func makeSlice(n int) []byte { ...@@ -182,14 +182,21 @@ func makeSlice(n int) []byte {
func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) { func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) {
b.lastRead = opInvalid b.lastRead = opInvalid
if b.off < len(b.buf) { if b.off < len(b.buf) {
nBytes := b.Len()
m, e := w.Write(b.buf[b.off:]) m, e := w.Write(b.buf[b.off:])
if m > nBytes {
panic("bytes.Buffer.WriteTo: invalid Write count")
}
b.off += m b.off += m
n = int64(m) n = int64(m)
if e != nil { if e != nil {
return n, e return n, e
} }
// otherwise all bytes were written, by definition of // all bytes should have been written, by definition of
// Write method in io.Writer // Write method in io.Writer
if m != nBytes {
return n, io.ErrShortWrite
}
} }
// Buffer is now empty; reset. // Buffer is now empty; reset.
b.Truncate(0) b.Truncate(0)
......
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bytes
import (
"errors"
"io"
"unicode/utf8"
)
// A Reader implements the io.Reader, io.ReaderAt, io.Seeker,
// io.ByteScanner, and io.RuneScanner interfaces by reading from
// a byte slice.
// Unlike a Buffer, a Reader is read-only and supports seeking.
type Reader struct {
s []byte
i int // current reading index
prevRune int // index of previous rune; or < 0
}
// Len returns the number of bytes of the unread portion of the
// slice.
func (r *Reader) Len() int {
if r.i >= len(r.s) {
return 0
}
return len(r.s) - r.i
}
func (r *Reader) Read(b []byte) (n int, err error) {
if len(b) == 0 {
return 0, nil
}
if r.i >= len(r.s) {
return 0, io.EOF
}
n = copy(b, r.s[r.i:])
r.i += n
r.prevRune = -1
return
}
func (r *Reader) ReadAt(b []byte, off int64) (n int, err error) {
if off < 0 {
return 0, errors.New("bytes: invalid offset")
}
if off >= int64(len(r.s)) {
return 0, io.EOF
}
n = copy(b, r.s[int(off):])
if n < len(b) {
err = io.EOF
}
return
}
func (r *Reader) ReadByte() (b byte, err error) {
if r.i >= len(r.s) {
return 0, io.EOF
}
b = r.s[r.i]
r.i++
r.prevRune = -1
return
}
func (r *Reader) UnreadByte() error {
if r.i <= 0 {
return errors.New("bytes.Reader: at beginning of slice")
}
r.i--
r.prevRune = -1
return nil
}
func (r *Reader) ReadRune() (ch rune, size int, err error) {
if r.i >= len(r.s) {
return 0, 0, io.EOF
}
r.prevRune = r.i
if c := r.s[r.i]; c < utf8.RuneSelf {
r.i++
return rune(c), 1, nil
}
ch, size = utf8.DecodeRune(r.s[r.i:])
r.i += size
return
}
func (r *Reader) UnreadRune() error {
if r.prevRune < 0 {
return errors.New("bytes.Reader: previous operation was not ReadRune")
}
r.i = r.prevRune
r.prevRune = -1
return nil
}
// Seek implements the io.Seeker interface.
func (r *Reader) Seek(offset int64, whence int) (int64, error) {
var abs int64
switch whence {
case 0:
abs = offset
case 1:
abs = int64(r.i) + offset
case 2:
abs = int64(len(r.s)) + offset
default:
return 0, errors.New("bytes: invalid whence")
}
if abs < 0 {
return 0, errors.New("bytes: negative position")
}
if abs >= 1<<31 {
return 0, errors.New("bytes: position out of range")
}
r.i = int(abs)
return abs, nil
}
// NewReader returns a new Reader reading from b.
func NewReader(b []byte) *Reader { return &Reader{b, 0, -1} }
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bytes_test
import (
. "bytes"
"fmt"
"io"
"os"
"testing"
)
func TestReader(t *testing.T) {
r := NewReader([]byte("0123456789"))
tests := []struct {
off int64
seek int
n int
want string
wantpos int64
seekerr string
}{
{seek: os.SEEK_SET, off: 0, n: 20, want: "0123456789"},
{seek: os.SEEK_SET, off: 1, n: 1, want: "1"},
{seek: os.SEEK_CUR, off: 1, wantpos: 3, n: 2, want: "34"},
{seek: os.SEEK_SET, off: -1, seekerr: "bytes: negative position"},
{seek: os.SEEK_SET, off: 1<<31 - 1},
{seek: os.SEEK_CUR, off: 1, seekerr: "bytes: position out of range"},
{seek: os.SEEK_SET, n: 5, want: "01234"},
{seek: os.SEEK_CUR, n: 5, want: "56789"},
{seek: os.SEEK_END, off: -1, n: 1, wantpos: 9, want: "9"},
}
for i, tt := range tests {
pos, err := r.Seek(tt.off, tt.seek)
if err == nil && tt.seekerr != "" {
t.Errorf("%d. want seek error %q", i, tt.seekerr)
continue
}
if err != nil && err.Error() != tt.seekerr {
t.Errorf("%d. seek error = %q; want %q", i, err.Error(), tt.seekerr)
continue
}
if tt.wantpos != 0 && tt.wantpos != pos {
t.Errorf("%d. pos = %d, want %d", i, pos, tt.wantpos)
}
buf := make([]byte, tt.n)
n, err := r.Read(buf)
if err != nil {
t.Errorf("%d. read = %v", i, err)
continue
}
got := string(buf[:n])
if got != tt.want {
t.Errorf("%d. got %q; want %q", i, got, tt.want)
}
}
}
func TestReaderAt(t *testing.T) {
r := NewReader([]byte("0123456789"))
tests := []struct {
off int64
n int
want string
wanterr interface{}
}{
{0, 10, "0123456789", nil},
{1, 10, "123456789", io.EOF},
{1, 9, "123456789", nil},
{11, 10, "", io.EOF},
{0, 0, "", nil},
{-1, 0, "", "bytes: invalid offset"},
}
for i, tt := range tests {
b := make([]byte, tt.n)
rn, err := r.ReadAt(b, tt.off)
got := string(b[:rn])
if got != tt.want {
t.Errorf("%d. got %q; want %q", i, got, tt.want)
}
if fmt.Sprintf("%v", err) != fmt.Sprintf("%v", tt.wanterr) {
t.Errorf("%d. got error = %v; want %v", i, err, tt.wanterr)
}
}
}
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package flate package flate
import ( import (
"fmt"
"io" "io"
"math" "math"
) )
...@@ -390,7 +391,7 @@ func (d *compressor) init(w io.Writer, level int) (err error) { ...@@ -390,7 +391,7 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
d.fill = (*compressor).fillDeflate d.fill = (*compressor).fillDeflate
d.step = (*compressor).deflate d.step = (*compressor).deflate
default: default:
return WrongValueError{"level", 0, 9, int32(level)} return fmt.Errorf("flate: invalid compression level %d: want value in range [-1, 9]", level)
} }
return nil return nil
} }
...@@ -408,17 +409,22 @@ func (d *compressor) close() error { ...@@ -408,17 +409,22 @@ func (d *compressor) close() error {
return d.w.err return d.w.err
} }
// NewWriter returns a new Writer compressing // NewWriter returns a new Writer compressing data at the given level.
// data at the given level. Following zlib, levels // Following zlib, levels range from 1 (BestSpeed) to 9 (BestCompression);
// range from 1 (BestSpeed) to 9 (BestCompression); // higher levels typically run slower but compress more. Level 0
// higher levels typically run slower but compress more. // (NoCompression) does not attempt any compression; it only adds the
// Level 0 (NoCompression) does not attempt any // necessary DEFLATE framing. Level -1 (DefaultCompression) uses the default
// compression; it only adds the necessary DEFLATE framing. // compression level.
func NewWriter(w io.Writer, level int) *Writer { //
// If level is in the range [-1, 9] then the error returned will be nil.
// Otherwise the error returned will be non-nil.
func NewWriter(w io.Writer, level int) (*Writer, error) {
const logWindowSize = logMaxOffsetSize const logWindowSize = logMaxOffsetSize
var dw Writer var dw Writer
dw.d.init(w, level) if err := dw.d.init(w, level); err != nil {
return &dw return nil, err
}
return &dw, nil
} }
// NewWriterDict is like NewWriter but initializes the new // NewWriterDict is like NewWriter but initializes the new
...@@ -427,13 +433,16 @@ func NewWriter(w io.Writer, level int) *Writer { ...@@ -427,13 +433,16 @@ func NewWriter(w io.Writer, level int) *Writer {
// any compressed output. The compressed data written to w // any compressed output. The compressed data written to w
// can only be decompressed by a Reader initialized with the // can only be decompressed by a Reader initialized with the
// same dictionary. // same dictionary.
func NewWriterDict(w io.Writer, level int, dict []byte) *Writer { func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
dw := &dictWriter{w, false} dw := &dictWriter{w, false}
zw := NewWriter(dw, level) zw, err := NewWriter(dw, level)
if err != nil {
return nil, err
}
zw.Write(dict) zw.Write(dict)
zw.Flush() zw.Flush()
dw.enabled = true dw.enabled = true
return zw return zw, err
} }
type dictWriter struct { type dictWriter struct {
......
...@@ -81,7 +81,11 @@ func largeDataChunk() []byte { ...@@ -81,7 +81,11 @@ func largeDataChunk() []byte {
func TestDeflate(t *testing.T) { func TestDeflate(t *testing.T) {
for _, h := range deflateTests { for _, h := range deflateTests {
var buf bytes.Buffer var buf bytes.Buffer
w := NewWriter(&buf, h.level) w, err := NewWriter(&buf, h.level)
if err != nil {
t.Errorf("NewWriter: %v", err)
continue
}
w.Write(h.in) w.Write(h.in)
w.Close() w.Close()
if !bytes.Equal(buf.Bytes(), h.out) { if !bytes.Equal(buf.Bytes(), h.out) {
...@@ -151,7 +155,11 @@ func testSync(t *testing.T, level int, input []byte, name string) { ...@@ -151,7 +155,11 @@ func testSync(t *testing.T, level int, input []byte, name string) {
buf := newSyncBuffer() buf := newSyncBuffer()
buf1 := new(bytes.Buffer) buf1 := new(bytes.Buffer)
buf.WriteMode() buf.WriteMode()
w := NewWriter(io.MultiWriter(buf, buf1), level) w, err := NewWriter(io.MultiWriter(buf, buf1), level)
if err != nil {
t.Errorf("NewWriter: %v", err)
return
}
r := NewReader(buf) r := NewReader(buf)
// Write half the input and read back. // Write half the input and read back.
...@@ -213,7 +221,7 @@ func testSync(t *testing.T, level int, input []byte, name string) { ...@@ -213,7 +221,7 @@ func testSync(t *testing.T, level int, input []byte, name string) {
// stream should work for ordinary reader too // stream should work for ordinary reader too
r = NewReader(buf1) r = NewReader(buf1)
out, err := ioutil.ReadAll(r) out, err = ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("testSync: read: %s", err) t.Errorf("testSync: read: %s", err)
return return
...@@ -224,31 +232,31 @@ func testSync(t *testing.T, level int, input []byte, name string) { ...@@ -224,31 +232,31 @@ func testSync(t *testing.T, level int, input []byte, name string) {
} }
} }
func testToFromWithLevel(t *testing.T, level int, input []byte, name string) error { func testToFromWithLevelAndLimit(t *testing.T, level int, input []byte, name string, limit int) {
return testToFromWithLevelAndLimit(t, level, input, name, -1)
}
func testToFromWithLevelAndLimit(t *testing.T, level int, input []byte, name string, limit int) error {
var buffer bytes.Buffer var buffer bytes.Buffer
w := NewWriter(&buffer, level) w, err := NewWriter(&buffer, level)
if err != nil {
t.Errorf("NewWriter: %v", err)
return
}
w.Write(input) w.Write(input)
w.Close() w.Close()
if limit > 0 && buffer.Len() > limit { if limit > 0 && buffer.Len() > limit {
t.Errorf("level: %d, len(compress(data)) = %d > limit = %d", level, buffer.Len(), limit) t.Errorf("level: %d, len(compress(data)) = %d > limit = %d", level, buffer.Len(), limit)
return
} }
r := NewReader(&buffer) r := NewReader(&buffer)
out, err := ioutil.ReadAll(r) out, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("read: %s", err) t.Errorf("read: %s", err)
return err return
} }
r.Close() r.Close()
if !bytes.Equal(input, out) { if !bytes.Equal(input, out) {
t.Errorf("decompress(compress(data)) != data: level=%d input=%s", level, name) t.Errorf("decompress(compress(data)) != data: level=%d input=%s", level, name)
return
} }
testSync(t, level, input, name) testSync(t, level, input, name)
return nil
} }
func testToFromWithLimit(t *testing.T, input []byte, name string, limit [10]int) { func testToFromWithLimit(t *testing.T, input []byte, name string, limit [10]int) {
...@@ -257,13 +265,9 @@ func testToFromWithLimit(t *testing.T, input []byte, name string, limit [10]int) ...@@ -257,13 +265,9 @@ func testToFromWithLimit(t *testing.T, input []byte, name string, limit [10]int)
} }
} }
func testToFrom(t *testing.T, input []byte, name string) {
testToFromWithLimit(t, input, name, [10]int{})
}
func TestDeflateInflate(t *testing.T) { func TestDeflateInflate(t *testing.T) {
for i, h := range deflateInflateTests { for i, h := range deflateInflateTests {
testToFrom(t, h.in, fmt.Sprintf("#%d", i)) testToFromWithLimit(t, h.in, fmt.Sprintf("#%d", i), [10]int{})
} }
} }
...@@ -311,7 +315,10 @@ func TestReaderDict(t *testing.T) { ...@@ -311,7 +315,10 @@ func TestReaderDict(t *testing.T) {
text = "hello again world" text = "hello again world"
) )
var b bytes.Buffer var b bytes.Buffer
w := NewWriter(&b, 5) w, err := NewWriter(&b, 5)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
w.Write([]byte(dict)) w.Write([]byte(dict))
w.Flush() w.Flush()
b.Reset() b.Reset()
...@@ -334,7 +341,10 @@ func TestWriterDict(t *testing.T) { ...@@ -334,7 +341,10 @@ func TestWriterDict(t *testing.T) {
text = "hello again world" text = "hello again world"
) )
var b bytes.Buffer var b bytes.Buffer
w := NewWriter(&b, 5) w, err := NewWriter(&b, 5)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
w.Write([]byte(dict)) w.Write([]byte(dict))
w.Flush() w.Flush()
b.Reset() b.Reset()
...@@ -342,7 +352,7 @@ func TestWriterDict(t *testing.T) { ...@@ -342,7 +352,7 @@ func TestWriterDict(t *testing.T) {
w.Close() w.Close()
var b1 bytes.Buffer var b1 bytes.Buffer
w = NewWriterDict(&b1, 5, []byte(dict)) w, _ = NewWriterDict(&b1, 5, []byte(dict))
w.Write([]byte(text)) w.Write([]byte(text))
w.Close() w.Close()
...@@ -353,7 +363,10 @@ func TestWriterDict(t *testing.T) { ...@@ -353,7 +363,10 @@ func TestWriterDict(t *testing.T) {
// See http://code.google.com/p/go/issues/detail?id=2508 // See http://code.google.com/p/go/issues/detail?id=2508
func TestRegression2508(t *testing.T) { func TestRegression2508(t *testing.T) {
w := NewWriter(ioutil.Discard, 1) w, err := NewWriter(ioutil.Discard, 1)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
buf := make([]byte, 1024) buf := make([]byte, 1024)
for i := 0; i < 131072; i++ { for i := 0; i < 131072; i++ {
if _, err := w.Write(buf); err != nil { if _, err := w.Write(buf); err != nil {
......
...@@ -7,7 +7,6 @@ package flate ...@@ -7,7 +7,6 @@ package flate
import ( import (
"io" "io"
"math" "math"
"strconv"
) )
const ( const (
...@@ -85,13 +84,6 @@ type huffmanBitWriter struct { ...@@ -85,13 +84,6 @@ type huffmanBitWriter struct {
err error err error
} }
type WrongValueError struct {
name string
from int32
to int32
value int32
}
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{ return &huffmanBitWriter{
w: w, w: w,
...@@ -105,11 +97,6 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { ...@@ -105,11 +97,6 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
} }
} }
func (err WrongValueError) Error() string {
return "huffmanBitWriter: " + err.name + " should belong to [" + strconv.FormatInt(int64(err.from), 10) + ";" +
strconv.FormatInt(int64(err.to), 10) + "] but actual value is " + strconv.FormatInt(int64(err.value), 10)
}
func (w *huffmanBitWriter) flushBits() { func (w *huffmanBitWriter) flushBits() {
if w.err != nil { if w.err != nil {
w.nbits = 0 w.nbits = 0
......
...@@ -16,9 +16,6 @@ import ( ...@@ -16,9 +16,6 @@ import (
"time" "time"
) )
// BUG(nigeltao): Comments and Names don't properly map UTF-8 character codes outside of
// the 0x00-0x7f range to ISO 8859-1 (Latin-1).
const ( const (
gzipID1 = 0x1f gzipID1 = 0x1f
gzipID2 = 0x8b gzipID2 = 0x8b
...@@ -37,11 +34,15 @@ func makeReader(r io.Reader) flate.Reader { ...@@ -37,11 +34,15 @@ func makeReader(r io.Reader) flate.Reader {
return bufio.NewReader(r) return bufio.NewReader(r)
} }
var ErrHeader = errors.New("invalid gzip header") var (
var ErrChecksum = errors.New("gzip checksum error") // ErrChecksum is returned when reading GZIP data that has an invalid checksum.
ErrChecksum = errors.New("gzip: invalid checksum")
// ErrHeader is returned when reading GZIP data that has an invalid header.
ErrHeader = errors.New("gzip: invalid header")
)
// The gzip file stores a header giving metadata about the compressed file. // The gzip file stores a header giving metadata about the compressed file.
// That header is exposed as the fields of the Compressor and Decompressor structs. // That header is exposed as the fields of the Writer and Reader structs.
type Header struct { type Header struct {
Comment string // comment Comment string // comment
Extra []byte // "extra data" Extra []byte // "extra data"
...@@ -50,21 +51,21 @@ type Header struct { ...@@ -50,21 +51,21 @@ type Header struct {
OS byte // operating system type OS byte // operating system type
} }
// An Decompressor is an io.Reader that can be read to retrieve // A Reader is an io.Reader that can be read to retrieve
// uncompressed data from a gzip-format compressed file. // uncompressed data from a gzip-format compressed file.
// //
// In general, a gzip file can be a concatenation of gzip files, // In general, a gzip file can be a concatenation of gzip files,
// each with its own header. Reads from the Decompressor // each with its own header. Reads from the Reader
// return the concatenation of the uncompressed data of each. // return the concatenation of the uncompressed data of each.
// Only the first header is recorded in the Decompressor fields. // Only the first header is recorded in the Reader fields.
// //
// Gzip files store a length and checksum of the uncompressed data. // Gzip files store a length and checksum of the uncompressed data.
// The Decompressor will return a ErrChecksum when Read // The Reader will return a ErrChecksum when Read
// reaches the end of the uncompressed data if it does not // reaches the end of the uncompressed data if it does not
// have the expected length or checksum. Clients should treat data // have the expected length or checksum. Clients should treat data
// returned by Read as tentative until they receive the successful // returned by Read as tentative until they receive the io.EOF
// (zero length, nil error) Read marking the end of the data. // marking the end of the data.
type Decompressor struct { type Reader struct {
Header Header
r flate.Reader r flate.Reader
decompressor io.ReadCloser decompressor io.ReadCloser
...@@ -75,15 +76,14 @@ type Decompressor struct { ...@@ -75,15 +76,14 @@ type Decompressor struct {
err error err error
} }
// NewReader creates a new Decompressor reading the given reader. // NewReader creates a new Reader reading the given reader.
// The implementation buffers input and may read more data than necessary from r. // The implementation buffers input and may read more data than necessary from r.
// It is the caller's responsibility to call Close on the Decompressor when done. // It is the caller's responsibility to call Close on the Reader when done.
func NewReader(r io.Reader) (*Decompressor, error) { func NewReader(r io.Reader) (*Reader, error) {
z := new(Decompressor) z := new(Reader)
z.r = makeReader(r) z.r = makeReader(r)
z.digest = crc32.NewIEEE() z.digest = crc32.NewIEEE()
if err := z.readHeader(true); err != nil { if err := z.readHeader(true); err != nil {
z.err = err
return nil, err return nil, err
} }
return z, nil return z, nil
...@@ -94,7 +94,7 @@ func get4(p []byte) uint32 { ...@@ -94,7 +94,7 @@ func get4(p []byte) uint32 {
return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24 return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
} }
func (z *Decompressor) readString() (string, error) { func (z *Reader) readString() (string, error) {
var err error var err error
needconv := false needconv := false
for i := 0; ; i++ { for i := 0; ; i++ {
...@@ -123,7 +123,7 @@ func (z *Decompressor) readString() (string, error) { ...@@ -123,7 +123,7 @@ func (z *Decompressor) readString() (string, error) {
panic("not reached") panic("not reached")
} }
func (z *Decompressor) read2() (uint32, error) { func (z *Reader) read2() (uint32, error) {
_, err := io.ReadFull(z.r, z.buf[0:2]) _, err := io.ReadFull(z.r, z.buf[0:2])
if err != nil { if err != nil {
return 0, err return 0, err
...@@ -131,7 +131,7 @@ func (z *Decompressor) read2() (uint32, error) { ...@@ -131,7 +131,7 @@ func (z *Decompressor) read2() (uint32, error) {
return uint32(z.buf[0]) | uint32(z.buf[1])<<8, nil return uint32(z.buf[0]) | uint32(z.buf[1])<<8, nil
} }
func (z *Decompressor) readHeader(save bool) error { func (z *Reader) readHeader(save bool) error {
_, err := io.ReadFull(z.r, z.buf[0:10]) _, err := io.ReadFull(z.r, z.buf[0:10])
if err != nil { if err != nil {
return err return err
...@@ -197,7 +197,7 @@ func (z *Decompressor) readHeader(save bool) error { ...@@ -197,7 +197,7 @@ func (z *Decompressor) readHeader(save bool) error {
return nil return nil
} }
func (z *Decompressor) Read(p []byte) (n int, err error) { func (z *Reader) Read(p []byte) (n int, err error) {
if z.err != nil { if z.err != nil {
return 0, z.err return 0, z.err
} }
...@@ -237,5 +237,5 @@ func (z *Decompressor) Read(p []byte) (n int, err error) { ...@@ -237,5 +237,5 @@ func (z *Decompressor) Read(p []byte) (n int, err error) {
return z.Read(p) return z.Read(p)
} }
// Calling Close does not close the wrapped io.Reader originally passed to NewReader. // Close closes the Reader. It does not close the underlying io.Reader.
func (z *Decompressor) Close() error { return z.decompressor.Close() } func (z *Reader) Close() error { return z.decompressor.Close() }
...@@ -7,6 +7,7 @@ package gzip ...@@ -7,6 +7,7 @@ package gzip
import ( import (
"compress/flate" "compress/flate"
"errors" "errors"
"fmt"
"hash" "hash"
"hash/crc32" "hash/crc32"
"io" "io"
...@@ -21,9 +22,9 @@ const ( ...@@ -21,9 +22,9 @@ const (
DefaultCompression = flate.DefaultCompression DefaultCompression = flate.DefaultCompression
) )
// A Compressor is an io.WriteCloser that satisfies writes by compressing data written // A Writer is an io.WriteCloser that satisfies writes by compressing data written
// to its wrapped io.Writer. // to its wrapped io.Writer.
type Compressor struct { type Writer struct {
Header Header
w io.Writer w io.Writer
level int level int
...@@ -35,25 +36,40 @@ type Compressor struct { ...@@ -35,25 +36,40 @@ type Compressor struct {
err error err error
} }
// NewWriter calls NewWriterLevel with the default compression level. // NewWriter creates a new Writer that satisfies writes by compressing data
func NewWriter(w io.Writer) (*Compressor, error) { // written to w.
return NewWriterLevel(w, DefaultCompression) //
// It is the caller's responsibility to call Close on the WriteCloser when done.
// Writes may be buffered and not flushed until Close.
//
// Callers that wish to set the fields in Writer.Header must do so before
// the first call to Write or Close. The Comment and Name header fields are
// UTF-8 strings in Go, but the underlying format requires NUL-terminated ISO
// 8859-1 (Latin-1). NUL or non-Latin-1 runes in those strings will lead to an
// error on Write.
func NewWriter(w io.Writer) *Writer {
z, _ := NewWriterLevel(w, DefaultCompression)
return z
} }
// NewWriterLevel creates a new Compressor writing to the given writer. // NewWriterLevel is like NewWriter but specifies the compression level instead
// Writes may be buffered and not flushed until Close. // of assuming DefaultCompression.
// Callers that wish to set the fields in Compressor.Header must //
// do so before the first call to Write or Close. // The compression level can be DefaultCompression, NoCompression, or any
// It is the caller's responsibility to call Close on the WriteCloser when done. // integer value between BestSpeed and BestCompression inclusive. The error
// level is the compression level, which can be DefaultCompression, NoCompression, // returned will be nil if the level is valid.
// or any integer value between BestSpeed and BestCompression (inclusive). func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
func NewWriterLevel(w io.Writer, level int) (*Compressor, error) { if level < DefaultCompression || level > BestCompression {
z := new(Compressor) return nil, fmt.Errorf("gzip: invalid compression level: %d", level)
z.OS = 255 // unknown }
z.w = w return &Writer{
z.level = level Header: Header{
z.digest = crc32.NewIEEE() OS: 255, // unknown
return z, nil },
w: w,
level: level,
digest: crc32.NewIEEE(),
}, nil
} }
// GZIP (RFC 1952) is little-endian, unlike ZLIB (RFC 1950). // GZIP (RFC 1952) is little-endian, unlike ZLIB (RFC 1950).
...@@ -70,7 +86,7 @@ func put4(p []byte, v uint32) { ...@@ -70,7 +86,7 @@ func put4(p []byte, v uint32) {
} }
// writeBytes writes a length-prefixed byte slice to z.w. // writeBytes writes a length-prefixed byte slice to z.w.
func (z *Compressor) writeBytes(b []byte) error { func (z *Writer) writeBytes(b []byte) error {
if len(b) > 0xffff { if len(b) > 0xffff {
return errors.New("gzip.Write: Extra data is too large") return errors.New("gzip.Write: Extra data is too large")
} }
...@@ -83,10 +99,10 @@ func (z *Compressor) writeBytes(b []byte) error { ...@@ -83,10 +99,10 @@ func (z *Compressor) writeBytes(b []byte) error {
return err return err
} }
// writeString writes a string (in ISO 8859-1 (Latin-1) format) to z.w. // writeString writes a UTF-8 string s in GZIP's format to z.w.
func (z *Compressor) writeString(s string) error { // GZIP (RFC 1952) specifies that strings are NUL-terminated ISO 8859-1 (Latin-1).
// GZIP (RFC 1952) specifies that strings are NUL-terminated ISO 8859-1 (Latin-1). func (z *Writer) writeString(s string) (err error) {
var err error // GZIP stores Latin-1 strings; error if non-Latin-1; convert if non-ASCII.
needconv := false needconv := false
for _, v := range s { for _, v := range s {
if v == 0 || v > 0xff { if v == 0 || v > 0xff {
...@@ -114,7 +130,9 @@ func (z *Compressor) writeString(s string) error { ...@@ -114,7 +130,9 @@ func (z *Compressor) writeString(s string) error {
return err return err
} }
func (z *Compressor) Write(p []byte) (int, error) { // Write writes a compressed form of p to the underlying io.Writer. The
// compressed bytes are not necessarily flushed until the Writer is closed.
func (z *Writer) Write(p []byte) (int, error) {
if z.err != nil { if z.err != nil {
return 0, z.err return 0, z.err
} }
...@@ -165,7 +183,7 @@ func (z *Compressor) Write(p []byte) (int, error) { ...@@ -165,7 +183,7 @@ func (z *Compressor) Write(p []byte) (int, error) {
return n, z.err return n, z.err
} }
} }
z.compressor = flate.NewWriter(z.w, z.level) z.compressor, _ = flate.NewWriter(z.w, z.level)
} }
z.size += uint32(len(p)) z.size += uint32(len(p))
z.digest.Write(p) z.digest.Write(p)
...@@ -173,8 +191,8 @@ func (z *Compressor) Write(p []byte) (int, error) { ...@@ -173,8 +191,8 @@ func (z *Compressor) Write(p []byte) (int, error) {
return n, z.err return n, z.err
} }
// Calling Close does not close the wrapped io.Writer originally passed to NewWriter. // Close closes the Writer. It does not close the underlying io.Writer.
func (z *Compressor) Close() error { func (z *Writer) Close() error {
if z.err != nil { if z.err != nil {
return z.err return z.err
} }
......
...@@ -7,108 +7,153 @@ package gzip ...@@ -7,108 +7,153 @@ package gzip
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"io"
"io/ioutil" "io/ioutil"
"testing" "testing"
"time" "time"
) )
// pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the // TestEmpty tests that an empty payload still forms a valid GZIP stream.
// writer end and cfunc at the reader end. func TestEmpty(t *testing.T) {
func pipe(t *testing.T, dfunc func(*Compressor), cfunc func(*Decompressor)) { buf := new(bytes.Buffer)
piper, pipew := io.Pipe()
defer piper.Close() if err := NewWriter(buf).Close(); err != nil {
go func() { t.Fatalf("Writer.Close: %v", err)
defer pipew.Close() }
compressor, err := NewWriter(pipew)
if err != nil { r, err := NewReader(buf)
t.Fatalf("%v", err) if err != nil {
} t.Fatalf("NewReader: %v", err)
defer compressor.Close() }
dfunc(compressor) b, err := ioutil.ReadAll(r)
}()
decompressor, err := NewReader(piper)
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("ReadAll: %v", err)
}
if len(b) != 0 {
t.Fatalf("got %d bytes, want 0", len(b))
}
if err := r.Close(); err != nil {
t.Fatalf("Reader.Close: %v", err)
} }
defer decompressor.Close()
cfunc(decompressor)
} }
// Tests that an empty payload still forms a valid GZIP stream. // TestRoundTrip tests that gzipping and then gunzipping is the identity
func TestEmpty(t *testing.T) { // function.
pipe(t, func TestRoundTrip(t *testing.T) {
func(compressor *Compressor) {}, buf := new(bytes.Buffer)
func(decompressor *Decompressor) {
b, err := ioutil.ReadAll(decompressor) w := NewWriter(buf)
if err != nil { w.Comment = "comment"
t.Fatalf("%v", err) w.Extra = []byte("extra")
} w.ModTime = time.Unix(1e8, 0)
if len(b) != 0 { w.Name = "name"
t.Fatalf("did not read an empty slice") if _, err := w.Write([]byte("payload")); err != nil {
} t.Fatalf("Write: %v", err)
}) }
} if err := w.Close(); err != nil {
t.Fatalf("Writer.Close: %v", err)
}
// Tests that gzipping and then gunzipping is the identity function. r, err := NewReader(buf)
func TestWriter(t *testing.T) { if err != nil {
pipe(t, t.Fatalf("NewReader: %v", err)
func(compressor *Compressor) { }
compressor.Comment = "Äußerung" b, err := ioutil.ReadAll(r)
//compressor.Comment = "comment" if err != nil {
compressor.Extra = []byte("extra") t.Fatalf("ReadAll: %v", err)
compressor.ModTime = time.Unix(1e8, 0) }
compressor.Name = "name" if string(b) != "payload" {
_, err := compressor.Write([]byte("payload")) t.Fatalf("payload is %q, want %q", string(b), "payload")
if err != nil { }
t.Fatalf("%v", err) if r.Comment != "comment" {
} t.Fatalf("comment is %q, want %q", r.Comment, "comment")
}, }
func(decompressor *Decompressor) { if string(r.Extra) != "extra" {
b, err := ioutil.ReadAll(decompressor) t.Fatalf("extra is %q, want %q", r.Extra, "extra")
if err != nil { }
t.Fatalf("%v", err) if r.ModTime.Unix() != 1e8 {
} t.Fatalf("mtime is %d, want %d", r.ModTime.Unix(), uint32(1e8))
if string(b) != "payload" { }
t.Fatalf("payload is %q, want %q", string(b), "payload") if r.Name != "name" {
} t.Fatalf("name is %q, want %q", r.Name, "name")
if decompressor.Comment != "Äußerung" { }
t.Fatalf("comment is %q, want %q", decompressor.Comment, "Äußerung") if err := r.Close(); err != nil {
} t.Fatalf("Reader.Close: %v", err)
if string(decompressor.Extra) != "extra" { }
t.Fatalf("extra is %q, want %q", decompressor.Extra, "extra")
}
if decompressor.ModTime.Unix() != 1e8 {
t.Fatalf("mtime is %d, want %d", decompressor.ModTime.Unix(), uint32(1e8))
}
if decompressor.Name != "name" {
t.Fatalf("name is %q, want %q", decompressor.Name, "name")
}
})
} }
// TestLatin1 tests the internal functions for converting to and from Latin-1.
func TestLatin1(t *testing.T) { func TestLatin1(t *testing.T) {
latin1 := []byte{0xc4, 'u', 0xdf, 'e', 'r', 'u', 'n', 'g', 0} latin1 := []byte{0xc4, 'u', 0xdf, 'e', 'r', 'u', 'n', 'g', 0}
utf8 := "Äußerung" utf8 := "Äußerung"
z := Decompressor{r: bufio.NewReader(bytes.NewBuffer(latin1))} z := Reader{r: bufio.NewReader(bytes.NewBuffer(latin1))}
s, err := z.readString() s, err := z.readString()
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("readString: %v", err)
} }
if s != utf8 { if s != utf8 {
t.Fatalf("string is %q, want %q", s, utf8) t.Fatalf("read latin-1: got %q, want %q", s, utf8)
} }
buf := bytes.NewBuffer(make([]byte, 0, len(latin1))) buf := bytes.NewBuffer(make([]byte, 0, len(latin1)))
c := Compressor{w: buf} c := Writer{w: buf}
if err = c.writeString(utf8); err != nil { if err = c.writeString(utf8); err != nil {
t.Fatalf("%v", err) t.Fatalf("writeString: %v", err)
} }
s = buf.String() s = buf.String()
if s != string(latin1) { if s != string(latin1) {
t.Fatalf("string is %v, want %v", s, latin1) t.Fatalf("write utf-8: got %q, want %q", s, string(latin1))
}
}
// TestLatin1RoundTrip tests that metadata that is representable in Latin-1
// survives a round trip.
func TestLatin1RoundTrip(t *testing.T) {
testCases := []struct {
name string
ok bool
}{
{"", true},
{"ASCII is OK", true},
{"unless it contains a NUL\x00", false},
{"no matter where \x00 occurs", false},
{"\x00\x00\x00", false},
{"Látin-1 also passes (U+00E1)", true},
{"but LĀtin Extended-A (U+0100) does not", false},
{"neither does 日本語", false},
{"invalid UTF-8 also \xffails", false},
{"\x00 as does Látin-1 with NUL", false},
}
for _, tc := range testCases {
buf := new(bytes.Buffer)
w := NewWriter(buf)
w.Name = tc.name
err := w.Close()
if (err == nil) != tc.ok {
t.Errorf("Writer.Close: name = %q, err = %v", tc.name, err)
continue
}
if !tc.ok {
continue
}
r, err := NewReader(buf)
if err != nil {
t.Errorf("NewReader: %v", err)
continue
}
_, err = ioutil.ReadAll(r)
if err != nil {
t.Errorf("ReadAll: %v", err)
continue
}
if r.Name != tc.name {
t.Errorf("name is %q, want %q", r.Name, tc.name)
continue
}
if err := r.Close(); err != nil {
t.Errorf("Reader.Close: %v", err)
continue
}
} }
//if s, err = buf.ReadString(0); err != nil {
//t.Fatalf("%v", err)
//}
} }
...@@ -34,9 +34,14 @@ import ( ...@@ -34,9 +34,14 @@ import (
const zlibDeflate = 8 const zlibDeflate = 8
var ErrChecksum = errors.New("zlib checksum error") var (
var ErrHeader = errors.New("invalid zlib header") // ErrChecksum is returned when reading ZLIB data that has an invalid checksum.
var ErrDictionary = errors.New("invalid zlib dictionary") ErrChecksum = errors.New("zlib: invalid checksum")
// ErrDictionary is returned when reading ZLIB data that has an invalid dictionary.
ErrDictionary = errors.New("zlib: invalid dictionary")
// ErrHeader is returned when reading ZLIB data that has an invalid header.
ErrHeader = errors.New("zlib: invalid header")
)
type reader struct { type reader struct {
r flate.Reader r flate.Reader
......
...@@ -6,7 +6,7 @@ package zlib ...@@ -6,7 +6,7 @@ package zlib
import ( import (
"compress/flate" "compress/flate"
"errors" "fmt"
"hash" "hash"
"hash/adler32" "hash/adler32"
"io" "io"
...@@ -24,30 +24,55 @@ const ( ...@@ -24,30 +24,55 @@ const (
// A Writer takes data written to it and writes the compressed // A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter). // form of that data to an underlying writer (see NewWriter).
type Writer struct { type Writer struct {
w io.Writer w io.Writer
compressor *flate.Writer level int
digest hash.Hash32 dict []byte
err error compressor *flate.Writer
scratch [4]byte digest hash.Hash32
err error
scratch [4]byte
wroteHeader bool
} }
// NewWriter calls NewWriterLevel with the default compression level. // NewWriter creates a new Writer that satisfies writes by compressing data
func NewWriter(w io.Writer) (*Writer, error) { // written to w.
return NewWriterLevel(w, DefaultCompression) //
// It is the caller's responsibility to call Close on the WriteCloser when done.
// Writes may be buffered and not flushed until Close.
func NewWriter(w io.Writer) *Writer {
z, _ := NewWriterLevelDict(w, DefaultCompression, nil)
return z
} }
// NewWriterLevel calls NewWriterDict with no dictionary. // NewWriterLevel is like NewWriter but specifies the compression level instead
// of assuming DefaultCompression.
//
// The compression level can be DefaultCompression, NoCompression, or any
// integer value between BestSpeed and BestCompression inclusive. The error
// returned will be nil if the level is valid.
func NewWriterLevel(w io.Writer, level int) (*Writer, error) { func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
return NewWriterDict(w, level, nil) return NewWriterLevelDict(w, level, nil)
} }
// NewWriterDict creates a new io.WriteCloser that satisfies writes by compressing data written to w. // NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
// It is the caller's responsibility to call Close on the WriteCloser when done. // compress with.
// level is the compression level, which can be DefaultCompression, NoCompression, //
// or any integer value between BestSpeed and BestCompression (inclusive). // The dictionary may be nil. If not, its contents should not be modified until
// dict is the preset dictionary to compress with, or nil to use no dictionary. // the Writer is closed.
func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) {
z := new(Writer) if level < DefaultCompression || level > BestCompression {
return nil, fmt.Errorf("zlib: invalid compression level: %d", level)
}
return &Writer{
w: w,
level: level,
dict: dict,
}, nil
}
// writeHeader writes the ZLIB header.
func (z *Writer) writeHeader() (err error) {
z.wroteHeader = true
// ZLIB has a two-byte header (as documented in RFC 1950). // ZLIB has a two-byte header (as documented in RFC 1950).
// The first four bits is the CINFO (compression info), which is 7 for the default deflate window size. // The first four bits is the CINFO (compression info), which is 7 for the default deflate window size.
// The next four bits is the CM (compression method), which is 8 for deflate. // The next four bits is the CM (compression method), which is 8 for deflate.
...@@ -56,7 +81,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { ...@@ -56,7 +81,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
// 0=fastest, 1=fast, 2=default, 3=best. // 0=fastest, 1=fast, 2=default, 3=best.
// The next bit, FDICT, is set if a dictionary is given. // The next bit, FDICT, is set if a dictionary is given.
// The final five FCHECK bits form a mod-31 checksum. // The final five FCHECK bits form a mod-31 checksum.
switch level { switch z.level {
case 0, 1: case 0, 1:
z.scratch[1] = 0 << 6 z.scratch[1] = 0 << 6
case 2, 3, 4, 5: case 2, 3, 4, 5:
...@@ -66,35 +91,41 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { ...@@ -66,35 +91,41 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
case 7, 8, 9: case 7, 8, 9:
z.scratch[1] = 3 << 6 z.scratch[1] = 3 << 6
default: default:
return nil, errors.New("level out of range") panic("unreachable")
} }
if dict != nil { if z.dict != nil {
z.scratch[1] |= 1 << 5 z.scratch[1] |= 1 << 5
} }
z.scratch[1] += uint8(31 - (uint16(z.scratch[0])<<8+uint16(z.scratch[1]))%31) z.scratch[1] += uint8(31 - (uint16(z.scratch[0])<<8+uint16(z.scratch[1]))%31)
_, err := w.Write(z.scratch[0:2]) if _, err = z.w.Write(z.scratch[0:2]); err != nil {
if err != nil { return err
return nil, err
} }
if dict != nil { if z.dict != nil {
// The next four bytes are the Adler-32 checksum of the dictionary. // The next four bytes are the Adler-32 checksum of the dictionary.
checksum := adler32.Checksum(dict) checksum := adler32.Checksum(z.dict)
z.scratch[0] = uint8(checksum >> 24) z.scratch[0] = uint8(checksum >> 24)
z.scratch[1] = uint8(checksum >> 16) z.scratch[1] = uint8(checksum >> 16)
z.scratch[2] = uint8(checksum >> 8) z.scratch[2] = uint8(checksum >> 8)
z.scratch[3] = uint8(checksum >> 0) z.scratch[3] = uint8(checksum >> 0)
_, err = w.Write(z.scratch[0:4]) if _, err = z.w.Write(z.scratch[0:4]); err != nil {
if err != nil { return err
return nil, err
} }
} }
z.w = w z.compressor, err = flate.NewWriterDict(z.w, z.level, z.dict)
z.compressor = flate.NewWriterDict(w, level, dict) if err != nil {
return err
}
z.digest = adler32.New() z.digest = adler32.New()
return z, nil return nil
} }
// Write writes a compressed form of p to the underlying io.Writer. The
// compressed bytes are not necessarily flushed until the Writer is closed or
// explicitly flushed.
func (z *Writer) Write(p []byte) (n int, err error) { func (z *Writer) Write(p []byte) (n int, err error) {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil { if z.err != nil {
return 0, z.err return 0, z.err
} }
...@@ -110,8 +141,11 @@ func (z *Writer) Write(p []byte) (n int, err error) { ...@@ -110,8 +141,11 @@ func (z *Writer) Write(p []byte) (n int, err error) {
return return
} }
// Flush flushes the underlying compressor. // Flush flushes the Writer to its underlying io.Writer.
func (z *Writer) Flush() error { func (z *Writer) Flush() error {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil { if z.err != nil {
return z.err return z.err
} }
...@@ -121,6 +155,9 @@ func (z *Writer) Flush() error { ...@@ -121,6 +155,9 @@ func (z *Writer) Flush() error {
// Calling Close does not close the wrapped io.Writer originally passed to NewWriter. // Calling Close does not close the wrapped io.Writer originally passed to NewWriter.
func (z *Writer) Close() error { func (z *Writer) Close() error {
if !z.wroteHeader {
z.err = z.writeHeader()
}
if z.err != nil { if z.err != nil {
return z.err return z.err
} }
......
...@@ -52,7 +52,7 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) { ...@@ -52,7 +52,7 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
defer piper.Close() defer piper.Close()
go func() { go func() {
defer pipew.Close() defer pipew.Close()
zlibw, err := NewWriterDict(pipew, level, dict) zlibw, err := NewWriterLevelDict(pipew, level, dict)
if err != nil { if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
...@@ -125,9 +125,9 @@ func TestWriterDict(t *testing.T) { ...@@ -125,9 +125,9 @@ func TestWriterDict(t *testing.T) {
func TestWriterDictIsUsed(t *testing.T) { func TestWriterDictIsUsed(t *testing.T) {
var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
var buf bytes.Buffer var buf bytes.Buffer
compressor, err := NewWriterDict(&buf, BestCompression, input) compressor, err := NewWriterLevelDict(&buf, BestCompression, input)
if err != nil { if err != nil {
t.Errorf("error in NewWriterDict: %s", err) t.Errorf("error in NewWriterLevelDict: %s", err)
return return
} }
compressor.Write(input) compressor.Write(input)
......
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This example demonstrates a priority queue built using the heap interface.
package heap_test
import (
"container/heap"
"fmt"
)
// An Item is something we manage in a priority queue.
type Item struct {
value string // The value of the item; arbitrary.
priority int // The priority of the item in the queue.
// The index is needed by changePriority and is maintained by the heap.Interface methods.
index int // The index of the item in the heap.
}
// A PriorityQueue implements heap.Interface and holds Items.
type PriorityQueue []*Item
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
// We want Pop to give us the highest, not lowest, priority so we use greater than here.
return pq[i].priority > pq[j].priority
}
func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq *PriorityQueue) Push(x interface{}) {
// Push and Pop use pointer receivers because they modify the slice's length,
// not just its contents.
// To simplify indexing expressions in these methods, we save a copy of the
// slice object. We could instead write (*pq)[i].
a := *pq
n := len(a)
a = a[0 : n+1]
item := x.(*Item)
item.index = n
a[n] = item
*pq = a
}
func (pq *PriorityQueue) Pop() interface{} {
a := *pq
n := len(a)
item := a[n-1]
item.index = -1 // for safety
*pq = a[0 : n-1]
return item
}
// 99:seven 88:five 77:zero 66:nine 55:three 44:two 33:six 22:one 11:four 00:eight
func ExampleInterface() {
// The full code of this example, including the methods that implement
// heap.Interface, is in the file src/pkg/container/heap/example_test.go.
const nItem = 10
// Random priorities for the items (a permutation of 0..9, times 11)).
priorities := [nItem]int{
77, 22, 44, 55, 11, 88, 33, 99, 00, 66,
}
values := [nItem]string{
"zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
}
// Create a priority queue and put some items in it.
pq := make(PriorityQueue, 0, nItem)
for i := 0; i < cap(pq); i++ {
item := &Item{
value: values[i],
priority: priorities[i],
}
heap.Push(&pq, item)
}
// Take the items out; should arrive in decreasing priority order.
// For example, the highest priority (99) is the seventh item, so output starts with 99:"seven".
for i := 0; i < nItem; i++ {
item := heap.Pop(&pq).(*Item)
fmt.Printf("%.2d:%s ", item.priority, item.value)
}
}
// update is not used by the example but shows how to take the top item from the queue,
// update its priority and value, and put it back.
func (pq *PriorityQueue) update(value string, priority int) {
item := heap.Pop(pq).(*Item)
item.value = value
item.priority = priority
heap.Push(pq, item)
}
// changePriority is not used by the example but shows how to change the priority of an arbitrary
// item.
func (pq *PriorityQueue) changePriority(item *Item, priority int) {
heap.Remove(pq, item.index)
item.priority = priority
heap.Push(pq, item)
}
...@@ -6,10 +6,11 @@ ...@@ -6,10 +6,11 @@
// heap.Interface. A heap is a tree with the property that each node is the // heap.Interface. A heap is a tree with the property that each node is the
// highest-valued node in its subtree. // highest-valued node in its subtree.
// //
// A heap is a common way to impement a priority queue. To build a priority // A heap is a common way to implement a priority queue. To build a priority
// queue, implement the Heap interface with the (negative) priority as the // queue, implement the Heap interface with the (negative) priority as the
// ordering for the Less method, so Push adds items while Pop removes the // ordering for the Less method, so Push adds items while Pop removes the
// highest-priority item from the queue. // highest-priority item from the queue. The Examples include such an
// implementation; the file example_test.go has the complete source.
// //
package heap package heap
......
...@@ -4,13 +4,16 @@ ...@@ -4,13 +4,16 @@
package aes package aes
import "strconv" import (
"crypto/cipher"
"strconv"
)
// The AES block size in bytes. // The AES block size in bytes.
const BlockSize = 16 const BlockSize = 16
// A Cipher is an instance of AES encryption using a particular key. // A cipher is an instance of AES encryption using a particular key.
type Cipher struct { type aesCipher struct {
enc []uint32 enc []uint32
dec []uint32 dec []uint32
} }
...@@ -21,11 +24,11 @@ func (k KeySizeError) Error() string { ...@@ -21,11 +24,11 @@ func (k KeySizeError) Error() string {
return "crypto/aes: invalid key size " + strconv.Itoa(int(k)) return "crypto/aes: invalid key size " + strconv.Itoa(int(k))
} }
// NewCipher creates and returns a new Cipher. // NewCipher creates and returns a new cipher.Block.
// The key argument should be the AES key, // The key argument should be the AES key,
// either 16, 24, or 32 bytes to select // either 16, 24, or 32 bytes to select
// AES-128, AES-192, or AES-256. // AES-128, AES-192, or AES-256.
func NewCipher(key []byte) (*Cipher, error) { func NewCipher(key []byte) (cipher.Block, error) {
k := len(key) k := len(key)
switch k { switch k {
default: default:
...@@ -35,34 +38,13 @@ func NewCipher(key []byte) (*Cipher, error) { ...@@ -35,34 +38,13 @@ func NewCipher(key []byte) (*Cipher, error) {
} }
n := k + 28 n := k + 28
c := &Cipher{make([]uint32, n), make([]uint32, n)} c := &aesCipher{make([]uint32, n), make([]uint32, n)}
expandKey(key, c.enc, c.dec) expandKey(key, c.enc, c.dec)
return c, nil return c, nil
} }
// BlockSize returns the AES block size, 16 bytes. func (c *aesCipher) BlockSize() int { return BlockSize }
// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize }
// Encrypt encrypts the 16-byte buffer src using the key k func (c *aesCipher) Encrypt(dst, src []byte) { encryptBlock(c.enc, dst, src) }
// and stores the result in dst.
// Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.enc, dst, src) }
// Decrypt decrypts the 16-byte buffer src using the key k func (c *aesCipher) Decrypt(dst, src []byte) { decryptBlock(c.dec, dst, src) }
// and stores the result in dst.
func (c *Cipher) Decrypt(dst, src []byte) { decryptBlock(c.dec, dst, src) }
// Reset zeros the key data, so that it will no longer
// appear in the process's memory.
func (c *Cipher) Reset() {
for i := 0; i < len(c.enc); i++ {
c.enc[i] = 0
}
for i := 0; i < len(c.dec); i++ {
c.dec[i] = 0
}
}
...@@ -8,11 +8,12 @@ ...@@ -8,11 +8,12 @@
// Special Publication 800-38A, ``Recommendation for Block Cipher // Special Publication 800-38A, ``Recommendation for Block Cipher
// Modes of Operation,'' 2001 Edition, pp. 24-29. // Modes of Operation,'' 2001 Edition, pp. 24-29.
package cipher package cipher_test
import ( import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher"
"testing" "testing"
) )
...@@ -72,14 +73,14 @@ func TestCBC_AES(t *testing.T) { ...@@ -72,14 +73,14 @@ func TestCBC_AES(t *testing.T) {
continue continue
} }
encrypter := NewCBCEncrypter(c, tt.iv) encrypter := cipher.NewCBCEncrypter(c, tt.iv)
d := make([]byte, len(tt.in)) d := make([]byte, len(tt.in))
encrypter.CryptBlocks(d, tt.in) encrypter.CryptBlocks(d, tt.in)
if !bytes.Equal(tt.out, d) { if !bytes.Equal(tt.out, d) {
t.Errorf("%s: CBCEncrypter\nhave %x\nwant %x", test, d, tt.out) t.Errorf("%s: CBCEncrypter\nhave %x\nwant %x", test, d, tt.out)
} }
decrypter := NewCBCDecrypter(c, tt.iv) decrypter := cipher.NewCBCDecrypter(c, tt.iv)
p := make([]byte, len(d)) p := make([]byte, len(d))
decrypter.CryptBlocks(p, d) decrypter.CryptBlocks(p, d)
if !bytes.Equal(tt.in, p) { if !bytes.Equal(tt.in, p) {
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package cipher package cipher_test
import ( import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher"
"crypto/rand" "crypto/rand"
"testing" "testing"
) )
...@@ -21,11 +22,11 @@ func TestCFB(t *testing.T) { ...@@ -21,11 +22,11 @@ func TestCFB(t *testing.T) {
plaintext := []byte("this is the plaintext") plaintext := []byte("this is the plaintext")
iv := make([]byte, block.BlockSize()) iv := make([]byte, block.BlockSize())
rand.Reader.Read(iv) rand.Reader.Read(iv)
cfb := NewCFBEncrypter(block, iv) cfb := cipher.NewCFBEncrypter(block, iv)
ciphertext := make([]byte, len(plaintext)) ciphertext := make([]byte, len(plaintext))
cfb.XORKeyStream(ciphertext, plaintext) cfb.XORKeyStream(ciphertext, plaintext)
cfbdec := NewCFBDecrypter(block, iv) cfbdec := cipher.NewCFBDecrypter(block, iv)
plaintextCopy := make([]byte, len(plaintext)) plaintextCopy := make([]byte, len(plaintext))
cfbdec.XORKeyStream(plaintextCopy, ciphertext) cfbdec.XORKeyStream(plaintextCopy, ciphertext)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package cipher package cipher_test
// Common values for tests. // Common values for tests.
......
...@@ -8,11 +8,12 @@ ...@@ -8,11 +8,12 @@
// Special Publication 800-38A, ``Recommendation for Block Cipher // Special Publication 800-38A, ``Recommendation for Block Cipher
// Modes of Operation,'' 2001 Edition, pp. 55-58. // Modes of Operation,'' 2001 Edition, pp. 55-58.
package cipher package cipher_test
import ( import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher"
"testing" "testing"
) )
...@@ -76,7 +77,7 @@ func TestCTR_AES(t *testing.T) { ...@@ -76,7 +77,7 @@ func TestCTR_AES(t *testing.T) {
for j := 0; j <= 5; j += 5 { for j := 0; j <= 5; j += 5 {
in := tt.in[0 : len(tt.in)-j] in := tt.in[0 : len(tt.in)-j]
ctr := NewCTR(c, tt.iv) ctr := cipher.NewCTR(c, tt.iv)
encrypted := make([]byte, len(in)) encrypted := make([]byte, len(in))
ctr.XORKeyStream(encrypted, in) ctr.XORKeyStream(encrypted, in)
if out := tt.out[0:len(in)]; !bytes.Equal(out, encrypted) { if out := tt.out[0:len(in)]; !bytes.Equal(out, encrypted) {
...@@ -86,7 +87,7 @@ func TestCTR_AES(t *testing.T) { ...@@ -86,7 +87,7 @@ func TestCTR_AES(t *testing.T) {
for j := 0; j <= 7; j += 7 { for j := 0; j <= 7; j += 7 {
in := tt.out[0 : len(tt.out)-j] in := tt.out[0 : len(tt.out)-j]
ctr := NewCTR(c, tt.iv) ctr := cipher.NewCTR(c, tt.iv)
plain := make([]byte, len(in)) plain := make([]byte, len(in))
ctr.XORKeyStream(plain, in) ctr.XORKeyStream(plain, in)
if out := tt.in[0:len(in)]; !bytes.Equal(out, plain) { if out := tt.in[0:len(in)]; !bytes.Equal(out, plain) {
......
...@@ -8,11 +8,12 @@ ...@@ -8,11 +8,12 @@
// Special Publication 800-38A, ``Recommendation for Block Cipher // Special Publication 800-38A, ``Recommendation for Block Cipher
// Modes of Operation,'' 2001 Edition, pp. 52-55. // Modes of Operation,'' 2001 Edition, pp. 52-55.
package cipher package cipher_test
import ( import (
"bytes" "bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher"
"testing" "testing"
) )
...@@ -76,7 +77,7 @@ func TestOFB(t *testing.T) { ...@@ -76,7 +77,7 @@ func TestOFB(t *testing.T) {
for j := 0; j <= 5; j += 5 { for j := 0; j <= 5; j += 5 {
plaintext := tt.in[0 : len(tt.in)-j] plaintext := tt.in[0 : len(tt.in)-j]
ofb := NewOFB(c, tt.iv) ofb := cipher.NewOFB(c, tt.iv)
ciphertext := make([]byte, len(plaintext)) ciphertext := make([]byte, len(plaintext))
ofb.XORKeyStream(ciphertext, plaintext) ofb.XORKeyStream(ciphertext, plaintext)
if !bytes.Equal(ciphertext, tt.out[:len(plaintext)]) { if !bytes.Equal(ciphertext, tt.out[:len(plaintext)]) {
...@@ -86,7 +87,7 @@ func TestOFB(t *testing.T) { ...@@ -86,7 +87,7 @@ func TestOFB(t *testing.T) {
for j := 0; j <= 5; j += 5 { for j := 0; j <= 5; j += 5 {
ciphertext := tt.out[0 : len(tt.in)-j] ciphertext := tt.out[0 : len(tt.in)-j]
ofb := NewOFB(c, tt.iv) ofb := cipher.NewOFB(c, tt.iv)
plaintext := make([]byte, len(ciphertext)) plaintext := make([]byte, len(ciphertext))
ofb.XORKeyStream(plaintext, ciphertext) ofb.XORKeyStream(plaintext, ciphertext)
if !bytes.Equal(plaintext, tt.in[:len(ciphertext)]) { if !bytes.Equal(plaintext, tt.in[:len(ciphertext)]) {
......
...@@ -79,7 +79,7 @@ func ksRotate(in uint32) (out []uint32) { ...@@ -79,7 +79,7 @@ func ksRotate(in uint32) (out []uint32) {
} }
// creates 16 56-bit subkeys from the original key // creates 16 56-bit subkeys from the original key
func (c *Cipher) generateSubkeys(keyBytes []byte) { func (c *desCipher) generateSubkeys(keyBytes []byte) {
// apply PC1 permutation to key // apply PC1 permutation to key
key := binary.BigEndian.Uint64(keyBytes) key := binary.BigEndian.Uint64(keyBytes)
permutedKey := permuteBlock(key, permutedChoice1[:]) permutedKey := permuteBlock(key, permutedChoice1[:])
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
package des package des
import "strconv" import (
"crypto/cipher"
"strconv"
)
// The DES block size in bytes. // The DES block size in bytes.
const BlockSize = 8 const BlockSize = 8
...@@ -15,86 +18,56 @@ func (k KeySizeError) Error() string { ...@@ -15,86 +18,56 @@ func (k KeySizeError) Error() string {
return "crypto/des: invalid key size " + strconv.Itoa(int(k)) return "crypto/des: invalid key size " + strconv.Itoa(int(k))
} }
// Cipher is an instance of DES encryption. // desCipher is an instance of DES encryption.
type Cipher struct { type desCipher struct {
subkeys [16]uint64 subkeys [16]uint64
} }
// NewCipher creates and returns a new Cipher. // NewCipher creates and returns a new cipher.Block.
func NewCipher(key []byte) (*Cipher, error) { func NewCipher(key []byte) (cipher.Block, error) {
if len(key) != 8 { if len(key) != 8 {
return nil, KeySizeError(len(key)) return nil, KeySizeError(len(key))
} }
c := new(Cipher) c := new(desCipher)
c.generateSubkeys(key) c.generateSubkeys(key)
return c, nil return c, nil
} }
// BlockSize returns the DES block size, 8 bytes. func (c *desCipher) BlockSize() int { return BlockSize }
func (c *Cipher) BlockSize() int { return BlockSize }
// Encrypt encrypts the 8-byte buffer src and stores the result in dst. func (c *desCipher) Encrypt(dst, src []byte) { encryptBlock(c.subkeys[:], dst, src) }
// Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.subkeys[:], dst, src) }
// Decrypt decrypts the 8-byte buffer src and stores the result in dst. func (c *desCipher) Decrypt(dst, src []byte) { decryptBlock(c.subkeys[:], dst, src) }
func (c *Cipher) Decrypt(dst, src []byte) { decryptBlock(c.subkeys[:], dst, src) }
// Reset zeros the key data, so that it will no longer // A tripleDESCipher is an instance of TripleDES encryption.
// appear in the process's memory. type tripleDESCipher struct {
func (c *Cipher) Reset() { cipher1, cipher2, cipher3 desCipher
for i := 0; i < len(c.subkeys); i++ {
c.subkeys[i] = 0
}
}
// A TripleDESCipher is an instance of TripleDES encryption.
type TripleDESCipher struct {
cipher1, cipher2, cipher3 Cipher
} }
// NewCipher creates and returns a new Cipher. // NewTripleDESCipher creates and returns a new cipher.Block.
func NewTripleDESCipher(key []byte) (*TripleDESCipher, error) { func NewTripleDESCipher(key []byte) (cipher.Block, error) {
if len(key) != 24 { if len(key) != 24 {
return nil, KeySizeError(len(key)) return nil, KeySizeError(len(key))
} }
c := new(TripleDESCipher) c := new(tripleDESCipher)
c.cipher1.generateSubkeys(key[:8]) c.cipher1.generateSubkeys(key[:8])
c.cipher2.generateSubkeys(key[8:16]) c.cipher2.generateSubkeys(key[8:16])
c.cipher3.generateSubkeys(key[16:]) c.cipher3.generateSubkeys(key[16:])
return c, nil return c, nil
} }
// BlockSize returns the TripleDES block size, 8 bytes. func (c *tripleDESCipher) BlockSize() int { return BlockSize }
// It is necessary to satisfy the Block interface in the
// package "crypto/cipher".
func (c *TripleDESCipher) BlockSize() int { return BlockSize }
// Encrypts the 8-byte buffer src and stores the result in dst. func (c *tripleDESCipher) Encrypt(dst, src []byte) {
// Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *TripleDESCipher) Encrypt(dst, src []byte) {
c.cipher1.Encrypt(dst, src) c.cipher1.Encrypt(dst, src)
c.cipher2.Decrypt(dst, dst) c.cipher2.Decrypt(dst, dst)
c.cipher3.Encrypt(dst, dst) c.cipher3.Encrypt(dst, dst)
} }
// Decrypts the 8-byte buffer src and stores the result in dst. func (c *tripleDESCipher) Decrypt(dst, src []byte) {
func (c *TripleDESCipher) Decrypt(dst, src []byte) {
c.cipher3.Decrypt(dst, src) c.cipher3.Decrypt(dst, src)
c.cipher2.Encrypt(dst, dst) c.cipher2.Encrypt(dst, dst)
c.cipher1.Decrypt(dst, dst) c.cipher1.Decrypt(dst, dst)
} }
// Reset zeros the key data, so that it will no longer
// appear in the process's memory.
func (c *TripleDESCipher) Reset() {
c.cipher1.Reset()
c.cipher2.Reset()
c.cipher3.Reset()
}
...@@ -1260,11 +1260,19 @@ var tableA4Tests = []CryptTest{ ...@@ -1260,11 +1260,19 @@ var tableA4Tests = []CryptTest{
[]byte{0x63, 0xfa, 0xc0, 0xd0, 0x34, 0xd9, 0xf7, 0x93}}, []byte{0x63, 0xfa, 0xc0, 0xd0, 0x34, 0xd9, 0xf7, 0x93}},
} }
func newCipher(key []byte) *desCipher {
c, err := NewCipher(key)
if err != nil {
panic("NewCipher failed: " + err.Error())
}
return c.(*desCipher)
}
// Use the known weak keys to test DES implementation // Use the known weak keys to test DES implementation
func TestWeakKeys(t *testing.T) { func TestWeakKeys(t *testing.T) {
for i, tt := range weakKeyTests { for i, tt := range weakKeyTests {
var encrypt = func(in []byte) (out []byte) { var encrypt = func(in []byte) (out []byte) {
c, _ := NewCipher(tt.key) c := newCipher(tt.key)
out = make([]byte, len(in)) out = make([]byte, len(in))
encryptBlock(c.subkeys[:], out, in) encryptBlock(c.subkeys[:], out, in)
return return
...@@ -1285,7 +1293,7 @@ func TestWeakKeys(t *testing.T) { ...@@ -1285,7 +1293,7 @@ func TestWeakKeys(t *testing.T) {
func TestSemiWeakKeyPairs(t *testing.T) { func TestSemiWeakKeyPairs(t *testing.T) {
for i, tt := range semiWeakKeyTests { for i, tt := range semiWeakKeyTests {
var encrypt = func(key, in []byte) (out []byte) { var encrypt = func(key, in []byte) (out []byte) {
c, _ := NewCipher(key) c := newCipher(key)
out = make([]byte, len(in)) out = make([]byte, len(in))
encryptBlock(c.subkeys[:], out, in) encryptBlock(c.subkeys[:], out, in)
return return
...@@ -1305,7 +1313,7 @@ func TestSemiWeakKeyPairs(t *testing.T) { ...@@ -1305,7 +1313,7 @@ func TestSemiWeakKeyPairs(t *testing.T) {
func TestDESEncryptBlock(t *testing.T) { func TestDESEncryptBlock(t *testing.T) {
for i, tt := range encryptDESTests { for i, tt := range encryptDESTests {
c, _ := NewCipher(tt.key) c := newCipher(tt.key)
out := make([]byte, len(tt.in)) out := make([]byte, len(tt.in))
encryptBlock(c.subkeys[:], out, tt.in) encryptBlock(c.subkeys[:], out, tt.in)
...@@ -1317,7 +1325,7 @@ func TestDESEncryptBlock(t *testing.T) { ...@@ -1317,7 +1325,7 @@ func TestDESEncryptBlock(t *testing.T) {
func TestDESDecryptBlock(t *testing.T) { func TestDESDecryptBlock(t *testing.T) {
for i, tt := range encryptDESTests { for i, tt := range encryptDESTests {
c, _ := NewCipher(tt.key) c := newCipher(tt.key)
plain := make([]byte, len(tt.in)) plain := make([]byte, len(tt.in))
decryptBlock(c.subkeys[:], plain, tt.out) decryptBlock(c.subkeys[:], plain, tt.out)
......
...@@ -29,17 +29,11 @@ type PrivateKey struct { ...@@ -29,17 +29,11 @@ type PrivateKey struct {
X *big.Int X *big.Int
} }
type invalidPublicKeyError int
func (invalidPublicKeyError) Error() string {
return "crypto/dsa: invalid public key"
}
// ErrInvalidPublicKey results when a public key is not usable by this code. // ErrInvalidPublicKey results when a public key is not usable by this code.
// FIPS is quite strict about the format of DSA keys, but other code may be // FIPS is quite strict about the format of DSA keys, but other code may be
// less so. Thus, when using keys which may have been generated by other code, // less so. Thus, when using keys which may have been generated by other code,
// this error must be handled. // this error must be handled.
var ErrInvalidPublicKey error = invalidPublicKeyError(0) var ErrInvalidPublicKey = errors.New("crypto/dsa: invalid public key")
// ParameterSizes is a enumeration of the acceptable bit lengths of the primes // ParameterSizes is a enumeration of the acceptable bit lengths of the primes
// in a set of DSA parameters. See FIPS 186-3, section 4.2. // in a set of DSA parameters. See FIPS 186-3, section 4.2.
......
...@@ -22,7 +22,7 @@ func TestRead(t *testing.T) { ...@@ -22,7 +22,7 @@ func TestRead(t *testing.T) {
} }
var z bytes.Buffer var z bytes.Buffer
f := flate.NewWriter(&z, 5) f, _ := flate.NewWriter(&z, 5)
f.Write(b) f.Write(b)
f.Close() f.Close()
if z.Len() < len(b)*99/100 { if z.Len() < len(b)*99/100 {
......
...@@ -12,6 +12,7 @@ package rand ...@@ -12,6 +12,7 @@ package rand
import ( import (
"bufio" "bufio"
"crypto/aes" "crypto/aes"
"crypto/cipher"
"io" "io"
"os" "os"
"sync" "sync"
...@@ -66,7 +67,7 @@ func newReader(entropy io.Reader) io.Reader { ...@@ -66,7 +67,7 @@ func newReader(entropy io.Reader) io.Reader {
type reader struct { type reader struct {
mu sync.Mutex mu sync.Mutex
budget int // number of bytes that can be generated budget int // number of bytes that can be generated
cipher *aes.Cipher cipher cipher.Block
entropy io.Reader entropy io.Reader
time, seed, dst, key [aes.BlockSize]byte time, seed, dst, key [aes.BlockSize]byte
} }
......
...@@ -21,7 +21,7 @@ import ( ...@@ -21,7 +21,7 @@ import (
func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) (out []byte, err error) { func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) (out []byte, err error) {
k := (pub.N.BitLen() + 7) / 8 k := (pub.N.BitLen() + 7) / 8
if len(msg) > k-11 { if len(msg) > k-11 {
err = MessageTooLongError{} err = ErrMessageTooLong
return return
} }
...@@ -47,7 +47,7 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) (out []byte, er ...@@ -47,7 +47,7 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) (out []byte, er
func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (out []byte, err error) { func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (out []byte, err error) {
valid, out, err := decryptPKCS1v15(rand, priv, ciphertext) valid, out, err := decryptPKCS1v15(rand, priv, ciphertext)
if err == nil && valid == 0 { if err == nil && valid == 0 {
err = DecryptionError{} err = ErrDecryption
} }
return return
...@@ -69,7 +69,7 @@ func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (out [ ...@@ -69,7 +69,7 @@ func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (out [
func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []byte, key []byte) (err error) { func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []byte, key []byte) (err error) {
k := (priv.N.BitLen() + 7) / 8 k := (priv.N.BitLen() + 7) / 8
if k-(len(key)+3+8) < 0 { if k-(len(key)+3+8) < 0 {
err = DecryptionError{} err = ErrDecryption
return return
} }
...@@ -86,7 +86,7 @@ func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []by ...@@ -86,7 +86,7 @@ func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []by
func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid int, msg []byte, err error) { func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid int, msg []byte, err error) {
k := (priv.N.BitLen() + 7) / 8 k := (priv.N.BitLen() + 7) / 8
if k < 11 { if k < 11 {
err = DecryptionError{} err = ErrDecryption
return return
} }
...@@ -170,7 +170,7 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b ...@@ -170,7 +170,7 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b
tLen := len(prefix) + hashLen tLen := len(prefix) + hashLen
k := (priv.N.BitLen() + 7) / 8 k := (priv.N.BitLen() + 7) / 8
if k < tLen+11 { if k < tLen+11 {
return nil, MessageTooLongError{} return nil, ErrMessageTooLong
} }
// EM = 0x00 || 0x01 || PS || 0x00 || T // EM = 0x00 || 0x01 || PS || 0x00 || T
...@@ -203,7 +203,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte) ...@@ -203,7 +203,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)
tLen := len(prefix) + hashLen tLen := len(prefix) + hashLen
k := (pub.N.BitLen() + 7) / 8 k := (pub.N.BitLen() + 7) / 8
if k < tLen+11 { if k < tLen+11 {
err = VerificationError{} err = ErrVerification
return return
} }
...@@ -223,7 +223,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte) ...@@ -223,7 +223,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)
} }
if ok != 1 { if ok != 1 {
return VerificationError{} return ErrVerification
} }
return nil return nil
......
...@@ -206,13 +206,9 @@ func mgf1XOR(out []byte, hash hash.Hash, seed []byte) { ...@@ -206,13 +206,9 @@ func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
} }
} }
// MessageTooLongError is returned when attempting to encrypt a message which // ErrMessageTooLong is returned when attempting to encrypt a message which is
// is too large for the size of the public key. // too large for the size of the public key.
type MessageTooLongError struct{} var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA public key size")
func (MessageTooLongError) Error() string {
return "message too long for RSA public key size"
}
func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int { func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int {
e := big.NewInt(int64(pub.E)) e := big.NewInt(int64(pub.E))
...@@ -227,7 +223,7 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l ...@@ -227,7 +223,7 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
hash.Reset() hash.Reset()
k := (pub.N.BitLen() + 7) / 8 k := (pub.N.BitLen() + 7) / 8
if len(msg) > k-2*hash.Size()-2 { if len(msg) > k-2*hash.Size()-2 {
err = MessageTooLongError{} err = ErrMessageTooLong
return return
} }
...@@ -266,17 +262,13 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l ...@@ -266,17 +262,13 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
return return
} }
// A DecryptionError represents a failure to decrypt a message. // ErrDecryption represents a failure to decrypt a message.
// It is deliberately vague to avoid adaptive attacks. // It is deliberately vague to avoid adaptive attacks.
type DecryptionError struct{} var ErrDecryption = errors.New("crypto/rsa: decryption error")
func (DecryptionError) Error() string { return "RSA decryption error" } // ErrVerification represents a failure to verify a signature.
// A VerificationError represents a failure to verify a signature.
// It is deliberately vague to avoid adaptive attacks. // It is deliberately vague to avoid adaptive attacks.
type VerificationError struct{} var ErrVerification = errors.New("crypto/rsa: verification error")
func (VerificationError) Error() string { return "RSA verification error" }
// modInverse returns ia, the inverse of a in the multiplicative group of prime // modInverse returns ia, the inverse of a in the multiplicative group of prime
// order n. It requires that a be a member of the group (i.e. less than n). // order n. It requires that a be a member of the group (i.e. less than n).
...@@ -338,7 +330,7 @@ func (priv *PrivateKey) Precompute() { ...@@ -338,7 +330,7 @@ func (priv *PrivateKey) Precompute() {
func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) { func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) {
// TODO(agl): can we get away with reusing blinds? // TODO(agl): can we get away with reusing blinds?
if c.Cmp(priv.N) > 0 { if c.Cmp(priv.N) > 0 {
err = DecryptionError{} err = ErrDecryption
return return
} }
...@@ -417,7 +409,7 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext ...@@ -417,7 +409,7 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
k := (priv.N.BitLen() + 7) / 8 k := (priv.N.BitLen() + 7) / 8
if len(ciphertext) > k || if len(ciphertext) > k ||
k < hash.Size()*2+2 { k < hash.Size()*2+2 {
err = DecryptionError{} err = ErrDecryption
return return
} }
...@@ -473,7 +465,7 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext ...@@ -473,7 +465,7 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
} }
if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 { if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
err = DecryptionError{} err = ErrDecryption
return return
} }
......
...@@ -87,9 +87,9 @@ func (c *Conn) RemoteAddr() net.Addr { ...@@ -87,9 +87,9 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.conn.RemoteAddr()
} }
// SetDeadline sets the read deadline associated with the connection. // SetDeadline sets the read and write deadlines associated with the connection.
// There is no write deadline. // A zero value for t means Read and Write will not time out.
// A zero value for t means Read will not time out. // After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetDeadline(t time.Time) error { func (c *Conn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t) return c.conn.SetDeadline(t)
} }
...@@ -100,10 +100,11 @@ func (c *Conn) SetReadDeadline(t time.Time) error { ...@@ -100,10 +100,11 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t) return c.conn.SetReadDeadline(t)
} }
// SetWriteDeadline exists to satisfy the net.Conn interface // SetWriteDeadline sets the write deadline on the underlying conneciton.
// but is not implemented by TLS. It always returns an error. // A zero value for t means Write will not time out.
// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
return errors.New("TLS does not support SetWriteDeadline") return c.conn.SetWriteDeadline(t)
} }
// A halfConn represents one direction of the record layer // A halfConn represents one direction of the record layer
...@@ -726,9 +727,13 @@ func (c *Conn) readHandshake() (interface{}, error) { ...@@ -726,9 +727,13 @@ func (c *Conn) readHandshake() (interface{}, error) {
} }
// Write writes data to the connection. // Write writes data to the connection.
func (c *Conn) Write(b []byte) (n int, err error) { func (c *Conn) Write(b []byte) (int, error) {
if err = c.Handshake(); err != nil { if c.err != nil {
return return 0, c.err
}
if c.err = c.Handshake(); c.err != nil {
return 0, c.err
} }
c.out.Lock() c.out.Lock()
...@@ -737,10 +742,10 @@ func (c *Conn) Write(b []byte) (n int, err error) { ...@@ -737,10 +742,10 @@ func (c *Conn) Write(b []byte) (n int, err error) {
if !c.handshakeComplete { if !c.handshakeComplete {
return 0, alertInternalError return 0, alertInternalError
} }
if c.err != nil {
return 0, c.err var n int
} n, c.err = c.writeRecord(recordTypeApplicationData, b)
return c.writeRecord(recordTypeApplicationData, b) return n, c.err
} }
// Read can be made to time out and return a net.Error with Timeout() == true // Read can be made to time out and return a net.Error with Timeout() == true
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build ignore
// Generate a self-signed X.509 certificate for a TLS server. Outputs to // Generate a self-signed X.509 certificate for a TLS server. Outputs to
// 'cert.pem' and 'key.pem' and will overwrite existing files. // 'cert.pem' and 'key.pem' and will overwrite existing files.
......
...@@ -62,7 +62,7 @@ func TestRunClient(t *testing.T) { ...@@ -62,7 +62,7 @@ func TestRunClient(t *testing.T) {
// Script of interaction with gnutls implementation. // Script of interaction with gnutls implementation.
// The values for this test are obtained by building and running in client mode: // The values for this test are obtained by building and running in client mode:
// % gotest -test.run "TestRunClient" -connect // % go test -run "TestRunClient" -connect
// and then: // and then:
// % gnutls-serv -p 10443 --debug 100 --x509keyfile key.pem --x509certfile cert.pem -a > /tmp/log 2>&1 // % gnutls-serv -p 10443 --debug 100 --x509keyfile key.pem --x509certfile cert.pem -a > /tmp/log 2>&1
// % python parse-gnutls-cli-debug-log.py < /tmp/log // % python parse-gnutls-cli-debug-log.py < /tmp/log
......
...@@ -284,7 +284,7 @@ func loadPEMCert(in string) *x509.Certificate { ...@@ -284,7 +284,7 @@ func loadPEMCert(in string) *x509.Certificate {
// Script of interaction with gnutls implementation. // Script of interaction with gnutls implementation.
// The values for this test are obtained by building and running in server mode: // The values for this test are obtained by building and running in server mode:
// % gotest -test.run "TestRunServer" -serve // % go test -run "TestRunServer" -serve
// and then: // and then:
// % gnutls-cli --insecure --debug 100 -p 10443 localhost > /tmp/log 2>&1 // % gnutls-cli --insecure --debug 100 -p 10443 localhost > /tmp/log 2>&1
// % python parse-gnutls-cli-debug-log.py < /tmp/log // % python parse-gnutls-cli-debug-log.py < /tmp/log
...@@ -949,7 +949,7 @@ var sslv3ServerScript = [][]byte{ ...@@ -949,7 +949,7 @@ var sslv3ServerScript = [][]byte{
var clientauthTests = []clientauthTest{ var clientauthTests = []clientauthTest{
// Server doesn't asks for cert // Server doesn't asks for cert
// gotest -test.run "TestRunServer" -serve -clientauth 0 // go test -run "TestRunServer" -serve -clientauth 0
// gnutls-cli --insecure --debug 100 -p 10443 localhost 2>&1 | // gnutls-cli --insecure --debug 100 -p 10443 localhost 2>&1 |
// python parse-gnutls-cli-debug-log.py // python parse-gnutls-cli-debug-log.py
{"NoClientCert", NoClientCert, nil, {"NoClientCert", NoClientCert, nil,
...@@ -1115,7 +1115,7 @@ var clientauthTests = []clientauthTest{ ...@@ -1115,7 +1115,7 @@ var clientauthTests = []clientauthTest{
0x03, 0x11, 0x43, 0x3e, 0xee, 0xb7, 0x4d, 0x69, 0x03, 0x11, 0x43, 0x3e, 0xee, 0xb7, 0x4d, 0x69,
}}}, }}},
// Server asks for cert with empty CA list, client doesn't give it. // Server asks for cert with empty CA list, client doesn't give it.
// gotest -test.run "TestRunServer" -serve -clientauth 1 // go test -run "TestRunServer" -serve -clientauth 1
// gnutls-cli --insecure --debug 100 -p 10443 localhost // gnutls-cli --insecure --debug 100 -p 10443 localhost
{"RequestClientCert, none given", RequestClientCert, nil, {"RequestClientCert, none given", RequestClientCert, nil,
[][]byte{{ [][]byte{{
...@@ -1282,7 +1282,7 @@ var clientauthTests = []clientauthTest{ ...@@ -1282,7 +1282,7 @@ var clientauthTests = []clientauthTest{
0xf4, 0x70, 0xcc, 0xb4, 0xed, 0x07, 0x76, 0x3a, 0xf4, 0x70, 0xcc, 0xb4, 0xed, 0x07, 0x76, 0x3a,
}}}, }}},
// Server asks for cert with empty CA list, client gives one // Server asks for cert with empty CA list, client gives one
// gotest -test.run "TestRunServer" -serve -clientauth 1 // go test -run "TestRunServer" -serve -clientauth 1
// gnutls-cli --insecure --debug 100 -p 10443 localhost // gnutls-cli --insecure --debug 100 -p 10443 localhost
{"RequestClientCert, client gives it", RequestClientCert, {"RequestClientCert, client gives it", RequestClientCert,
[]*x509.Certificate{clicert}, []*x509.Certificate{clicert},
......
...@@ -327,13 +327,9 @@ type Certificate struct { ...@@ -327,13 +327,9 @@ type Certificate struct {
PolicyIdentifiers []asn1.ObjectIdentifier PolicyIdentifiers []asn1.ObjectIdentifier
} }
// UnsupportedAlgorithmError results from attempting to perform an operation // ErrUnsupportedAlgorithm results from attempting to perform an operation that
// that involves algorithms that are not currently implemented. // involves algorithms that are not currently implemented.
type UnsupportedAlgorithmError struct{} var ErrUnsupportedAlgorithm = errors.New("crypto/x509: cannot verify signature: algorithm unimplemented")
func (UnsupportedAlgorithmError) Error() string {
return "cannot verify signature: algorithm unimplemented"
}
// ConstraintViolationError results when a requested usage is not permitted by // ConstraintViolationError results when a requested usage is not permitted by
// a certificate. For example: checking a signature when the public key isn't a // a certificate. For example: checking a signature when the public key isn't a
...@@ -341,7 +337,7 @@ func (UnsupportedAlgorithmError) Error() string { ...@@ -341,7 +337,7 @@ func (UnsupportedAlgorithmError) Error() string {
type ConstraintViolationError struct{} type ConstraintViolationError struct{}
func (ConstraintViolationError) Error() string { func (ConstraintViolationError) Error() string {
return "invalid signature: parent certificate cannot sign this kind of certificate" return "crypto/x509: invalid signature: parent certificate cannot sign this kind of certificate"
} }
func (c *Certificate) Equal(other *Certificate) bool { func (c *Certificate) Equal(other *Certificate) bool {
...@@ -366,7 +362,7 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err error) { ...@@ -366,7 +362,7 @@ func (c *Certificate) CheckSignatureFrom(parent *Certificate) (err error) {
} }
if parent.PublicKeyAlgorithm == UnknownPublicKeyAlgorithm { if parent.PublicKeyAlgorithm == UnknownPublicKeyAlgorithm {
return UnsupportedAlgorithmError{} return ErrUnsupportedAlgorithm
} }
// TODO(agl): don't ignore the path length constraint. // TODO(agl): don't ignore the path length constraint.
...@@ -389,12 +385,12 @@ func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature ...@@ -389,12 +385,12 @@ func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature
case SHA512WithRSA: case SHA512WithRSA:
hashType = crypto.SHA512 hashType = crypto.SHA512
default: default:
return UnsupportedAlgorithmError{} return ErrUnsupportedAlgorithm
} }
h := hashType.New() h := hashType.New()
if h == nil { if h == nil {
return UnsupportedAlgorithmError{} return ErrUnsupportedAlgorithm
} }
h.Write(signed) h.Write(signed)
...@@ -416,7 +412,7 @@ func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature ...@@ -416,7 +412,7 @@ func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature
} }
return return
} }
return UnsupportedAlgorithmError{} return ErrUnsupportedAlgorithm
} }
// CheckCRLSignature checks that the signature in crl is from c. // CheckCRLSignature checks that the signature in crl is from c.
...@@ -795,7 +791,7 @@ var ( ...@@ -795,7 +791,7 @@ var (
) )
func buildExtensions(template *Certificate) (ret []pkix.Extension, err error) { func buildExtensions(template *Certificate) (ret []pkix.Extension, err error) {
ret = make([]pkix.Extension, 7 /* maximum number of elements. */ ) ret = make([]pkix.Extension, 7 /* maximum number of elements. */)
n := 0 n := 0
if template.KeyUsage != 0 { if template.KeyUsage != 0 {
......
...@@ -90,8 +90,8 @@ func convertAssign(dest, src interface{}) error { ...@@ -90,8 +90,8 @@ func convertAssign(dest, src interface{}) error {
return nil return nil
} }
if scanner, ok := dest.(ScannerInto); ok { if scanner, ok := dest.(Scanner); ok {
return scanner.ScanInto(src) return scanner.Scan(src)
} }
dpv := reflect.ValueOf(dest) dpv := reflect.ValueOf(dest)
...@@ -110,6 +110,14 @@ func convertAssign(dest, src interface{}) error { ...@@ -110,6 +110,14 @@ func convertAssign(dest, src interface{}) error {
} }
switch dv.Kind() { switch dv.Kind() {
case reflect.Ptr:
if src == nil {
dv.Set(reflect.Zero(dv.Type()))
return nil
} else {
dv.Set(reflect.New(dv.Type().Elem()))
return convertAssign(dv.Interface(), src)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := asString(src) s := asString(src)
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
......
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
) )
var someTime = time.Unix(123, 0) var someTime = time.Unix(123, 0)
var answer int64 = 42
type conversionTest struct { type conversionTest struct {
s, d interface{} // source and destination s, d interface{} // source and destination
...@@ -27,6 +28,8 @@ type conversionTest struct { ...@@ -27,6 +28,8 @@ type conversionTest struct {
wantbool bool // used if d is of type *bool wantbool bool // used if d is of type *bool
wanterr string wanterr string
wantiface interface{} wantiface interface{}
wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr
wantnil bool // if true, *d must be *int64(nil)
} }
// Target variables for scanning into. // Target variables for scanning into.
...@@ -42,6 +45,7 @@ var ( ...@@ -42,6 +45,7 @@ var (
scanf32 float32 scanf32 float32
scanf64 float64 scanf64 float64
scantime time.Time scantime time.Time
scanptr *int64
scaniface interface{} scaniface interface{}
) )
...@@ -98,6 +102,10 @@ var conversionTests = []conversionTest{ ...@@ -98,6 +102,10 @@ var conversionTests = []conversionTest{
{s: "1.5", d: &scanf32, wantf32: float32(1.5)}, {s: "1.5", d: &scanf32, wantf32: float32(1.5)},
{s: "1.5", d: &scanf64, wantf64: float64(1.5)}, {s: "1.5", d: &scanf64, wantf64: float64(1.5)},
// Pointers
{s: interface{}(nil), d: &scanptr, wantnil: true},
{s: int64(42), d: &scanptr, wantptr: &answer},
// To interface{} // To interface{}
{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)}, {s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
{s: int64(1), d: &scaniface, wantiface: int64(1)}, {s: int64(1), d: &scaniface, wantiface: int64(1)},
...@@ -107,6 +115,10 @@ var conversionTests = []conversionTest{ ...@@ -107,6 +115,10 @@ var conversionTests = []conversionTest{
{s: nil, d: &scaniface}, {s: nil, d: &scaniface},
} }
func intPtrValue(intptr interface{}) interface{} {
return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
}
func intValue(intptr interface{}) int64 { func intValue(intptr interface{}) int64 {
return reflect.Indirect(reflect.ValueOf(intptr)).Int() return reflect.Indirect(reflect.ValueOf(intptr)).Int()
} }
...@@ -162,6 +174,16 @@ func TestConversions(t *testing.T) { ...@@ -162,6 +174,16 @@ func TestConversions(t *testing.T) {
if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) { if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
errf("want time %v, got %v", ct.wanttime, timeValue(ct.d)) errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
} }
if ct.wantnil && *ct.d.(**int64) != nil {
errf("want nil, got %v", intPtrValue(ct.d))
}
if ct.wantptr != nil {
if *ct.d.(**int64) == nil {
errf("want pointer to %v, got nil", *ct.wantptr)
} else if *ct.wantptr != intPtrValue(ct.d) {
errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
}
}
if ifptr, ok := ct.d.(*interface{}); ok { if ifptr, ok := ct.d.(*interface{}); ok {
if !reflect.DeepEqual(ct.wantiface, scaniface) { if !reflect.DeepEqual(ct.wantiface, scaniface) {
errf("want interface %#v, got %#v", ct.wantiface, scaniface) errf("want interface %#v, got %#v", ct.wantiface, scaniface)
......
...@@ -248,6 +248,13 @@ func (defaultConverter) ConvertValue(v interface{}) (interface{}, error) { ...@@ -248,6 +248,13 @@ func (defaultConverter) ConvertValue(v interface{}) (interface{}, error) {
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
switch rv.Kind() { switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
} else {
return defaultConverter{}.ConvertValue(rv.Elem().Interface())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
......
...@@ -18,6 +18,7 @@ type valueConverterTest struct { ...@@ -18,6 +18,7 @@ type valueConverterTest struct {
} }
var now = time.Now() var now = time.Now()
var answer int64 = 42
var valueConverterTests = []valueConverterTest{ var valueConverterTests = []valueConverterTest{
{Bool, "true", true, ""}, {Bool, "true", true, ""},
...@@ -37,6 +38,9 @@ var valueConverterTests = []valueConverterTest{ ...@@ -37,6 +38,9 @@ var valueConverterTests = []valueConverterTest{
{c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
{c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
{DefaultParameterConverter, now, now, ""}, {DefaultParameterConverter, now, now, ""},
{DefaultParameterConverter, (*int64)(nil), nil, ""},
{DefaultParameterConverter, &answer, answer, ""},
{DefaultParameterConverter, &now, now, ""},
} }
func TestValueConverters(t *testing.T) { func TestValueConverters(t *testing.T) {
......
...@@ -35,7 +35,7 @@ func Register(name string, driver driver.Driver) { ...@@ -35,7 +35,7 @@ func Register(name string, driver driver.Driver) {
type RawBytes []byte type RawBytes []byte
// NullString represents a string that may be null. // NullString represents a string that may be null.
// NullString implements the ScannerInto interface so // NullString implements the Scanner interface so
// it can be used as a scan destination: // it can be used as a scan destination:
// //
// var s NullString // var s NullString
...@@ -52,8 +52,8 @@ type NullString struct { ...@@ -52,8 +52,8 @@ type NullString struct {
Valid bool // Valid is true if String is not NULL Valid bool // Valid is true if String is not NULL
} }
// ScanInto implements the ScannerInto interface. // Scan implements the Scanner interface.
func (ns *NullString) ScanInto(value interface{}) error { func (ns *NullString) Scan(value interface{}) error {
if value == nil { if value == nil {
ns.String, ns.Valid = "", false ns.String, ns.Valid = "", false
return nil return nil
...@@ -71,15 +71,15 @@ func (ns NullString) SubsetValue() (interface{}, error) { ...@@ -71,15 +71,15 @@ func (ns NullString) SubsetValue() (interface{}, error) {
} }
// NullInt64 represents an int64 that may be null. // NullInt64 represents an int64 that may be null.
// NullInt64 implements the ScannerInto interface so // NullInt64 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString. // it can be used as a scan destination, similar to NullString.
type NullInt64 struct { type NullInt64 struct {
Int64 int64 Int64 int64
Valid bool // Valid is true if Int64 is not NULL Valid bool // Valid is true if Int64 is not NULL
} }
// ScanInto implements the ScannerInto interface. // Scan implements the Scanner interface.
func (n *NullInt64) ScanInto(value interface{}) error { func (n *NullInt64) Scan(value interface{}) error {
if value == nil { if value == nil {
n.Int64, n.Valid = 0, false n.Int64, n.Valid = 0, false
return nil return nil
...@@ -97,15 +97,15 @@ func (n NullInt64) SubsetValue() (interface{}, error) { ...@@ -97,15 +97,15 @@ func (n NullInt64) SubsetValue() (interface{}, error) {
} }
// NullFloat64 represents a float64 that may be null. // NullFloat64 represents a float64 that may be null.
// NullFloat64 implements the ScannerInto interface so // NullFloat64 implements the Scanner interface so
// it can be used as a scan destination, similar to NullString. // it can be used as a scan destination, similar to NullString.
type NullFloat64 struct { type NullFloat64 struct {
Float64 float64 Float64 float64
Valid bool // Valid is true if Float64 is not NULL Valid bool // Valid is true if Float64 is not NULL
} }
// ScanInto implements the ScannerInto interface. // Scan implements the Scanner interface.
func (n *NullFloat64) ScanInto(value interface{}) error { func (n *NullFloat64) Scan(value interface{}) error {
if value == nil { if value == nil {
n.Float64, n.Valid = 0, false n.Float64, n.Valid = 0, false
return nil return nil
...@@ -123,15 +123,15 @@ func (n NullFloat64) SubsetValue() (interface{}, error) { ...@@ -123,15 +123,15 @@ func (n NullFloat64) SubsetValue() (interface{}, error) {
} }
// NullBool represents a bool that may be null. // NullBool represents a bool that may be null.
// NullBool implements the ScannerInto interface so // NullBool implements the Scanner interface so
// it can be used as a scan destination, similar to NullString. // it can be used as a scan destination, similar to NullString.
type NullBool struct { type NullBool struct {
Bool bool Bool bool
Valid bool // Valid is true if Bool is not NULL Valid bool // Valid is true if Bool is not NULL
} }
// ScanInto implements the ScannerInto interface. // Scan implements the Scanner interface.
func (n *NullBool) ScanInto(value interface{}) error { func (n *NullBool) Scan(value interface{}) error {
if value == nil { if value == nil {
n.Bool, n.Valid = false, false n.Bool, n.Valid = false, false
return nil return nil
...@@ -148,22 +148,24 @@ func (n NullBool) SubsetValue() (interface{}, error) { ...@@ -148,22 +148,24 @@ func (n NullBool) SubsetValue() (interface{}, error) {
return n.Bool, nil return n.Bool, nil
} }
// ScannerInto is an interface used by Scan. // Scanner is an interface used by Scan.
type ScannerInto interface { type Scanner interface {
// ScanInto assigns a value from a database driver. // Scan assigns a value from a database driver.
// //
// The value will be of one of the following restricted // The src value will be of one of the following restricted
// set of types: // set of types:
// //
// int64 // int64
// float64 // float64
// bool // bool
// []byte // []byte
// string
// time.Time
// nil - for NULL values // nil - for NULL values
// //
// An error should be returned if the value can not be stored // An error should be returned if the value can not be stored
// without loss of information. // without loss of information.
ScanInto(value interface{}) error Scan(src interface{}) error
} }
// ErrNoRows is returned by Scan when QueryRow doesn't return a // ErrNoRows is returned by Scan when QueryRow doesn't return a
...@@ -368,7 +370,7 @@ func (db *DB) Begin() (*Tx, error) { ...@@ -368,7 +370,7 @@ func (db *DB) Begin() (*Tx, error) {
}, nil }, nil
} }
// DriverDatabase returns the database's underlying driver. // Driver returns the database's underlying driver.
func (db *DB) Driver() driver.Driver { func (db *DB) Driver() driver.Driver {
return db.driver return db.driver
} }
...@@ -378,7 +380,7 @@ func (db *DB) Driver() driver.Driver { ...@@ -378,7 +380,7 @@ func (db *DB) Driver() driver.Driver {
// A transaction must end with a call to Commit or Rollback. // A transaction must end with a call to Commit or Rollback.
// //
// After a call to Commit or Rollback, all operations on the // After a call to Commit or Rollback, all operations on the
// transaction fail with ErrTransactionFinished. // transaction fail with ErrTxDone.
type Tx struct { type Tx struct {
db *DB db *DB
...@@ -393,11 +395,11 @@ type Tx struct { ...@@ -393,11 +395,11 @@ type Tx struct {
// done transitions from false to true exactly once, on Commit // done transitions from false to true exactly once, on Commit
// or Rollback. once done, all operations fail with // or Rollback. once done, all operations fail with
// ErrTransactionFinished. // ErrTxDone.
done bool done bool
} }
var ErrTransactionFinished = errors.New("sql: Transaction has already been committed or rolled back") var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
func (tx *Tx) close() { func (tx *Tx) close() {
if tx.done { if tx.done {
...@@ -411,7 +413,7 @@ func (tx *Tx) close() { ...@@ -411,7 +413,7 @@ func (tx *Tx) close() {
func (tx *Tx) grabConn() (driver.Conn, error) { func (tx *Tx) grabConn() (driver.Conn, error) {
if tx.done { if tx.done {
return nil, ErrTransactionFinished return nil, ErrTxDone
} }
tx.cimu.Lock() tx.cimu.Lock()
return tx.ci, nil return tx.ci, nil
...@@ -424,7 +426,7 @@ func (tx *Tx) releaseConn() { ...@@ -424,7 +426,7 @@ func (tx *Tx) releaseConn() {
// Commit commits the transaction. // Commit commits the transaction.
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
if tx.done { if tx.done {
return ErrTransactionFinished return ErrTxDone
} }
defer tx.close() defer tx.close()
return tx.txi.Commit() return tx.txi.Commit()
...@@ -433,7 +435,7 @@ func (tx *Tx) Commit() error { ...@@ -433,7 +435,7 @@ func (tx *Tx) Commit() error {
// Rollback aborts the transaction. // Rollback aborts the transaction.
func (tx *Tx) Rollback() error { func (tx *Tx) Rollback() error {
if tx.done { if tx.done {
return ErrTransactionFinished return ErrTxDone
} }
defer tx.close() defer tx.close()
return tx.txi.Rollback() return tx.txi.Rollback()
...@@ -523,10 +525,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { ...@@ -523,10 +525,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
if execer, ok := ci.(driver.Execer); ok { if execer, ok := ci.(driver.Execer); ok {
resi, err := execer.Exec(query, args) resi, err := execer.Exec(query, args)
if err != nil { if err == nil {
return result{resi}, nil
}
if err != driver.ErrSkip {
return nil, err return nil, err
} }
return result{resi}, nil
} }
sti, err := ci.Prepare(query) sti, err := ci.Prepare(query)
...@@ -550,7 +554,7 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { ...@@ -550,7 +554,7 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
// Query executes a query that returns rows, typically a SELECT. // Query executes a query that returns rows, typically a SELECT.
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
if tx.done { if tx.done {
return nil, ErrTransactionFinished return nil, ErrTxDone
} }
stmt, err := tx.Prepare(query) stmt, err := tx.Prepare(query)
if err != nil { if err != nil {
...@@ -767,7 +771,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -767,7 +771,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
// Example usage: // Example usage:
// //
// var name string // var name string
// err := nameByUseridStmt.QueryRow(id).Scan(&s) // err := nameByUseridStmt.QueryRow(id).Scan(&name)
func (s *Stmt) QueryRow(args ...interface{}) *Row { func (s *Stmt) QueryRow(args ...interface{}) *Row {
rows, err := s.Query(args...) rows, err := s.Query(args...)
if err != nil { if err != nil {
......
...@@ -386,6 +386,38 @@ func TestNullByteSlice(t *testing.T) { ...@@ -386,6 +386,38 @@ func TestNullByteSlice(t *testing.T) {
} }
} }
func TestPointerParamsAndScans(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t|id=int32,name=nullstring")
bob := "bob"
var name *string
name = &bob
exec(t, db, "INSERT|t|id=10,name=?", name)
name = nil
exec(t, db, "INSERT|t|id=20,name=?", name)
err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
if err != nil {
t.Fatalf("querying id 10: %v", err)
}
if name == nil {
t.Errorf("id 10's name = nil; want bob")
} else if *name != "bob" {
t.Errorf("id 10's name = %q; want bob", *name)
}
err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name)
if err != nil {
t.Fatalf("querying id 20: %v", err)
}
if name != nil {
t.Errorf("id 20 = %q; want nil", *name)
}
}
func TestQueryRowClosingStmt(t *testing.T) { func TestQueryRowClosingStmt(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package macho implements access to Mach-O object files, as defined by // Package macho implements access to Mach-O object files.
// http://developer.apple.com/mac/library/documentation/DeveloperTools/Conceptual/MachORuntime/Reference/reference.html.
package macho package macho
// High level access to low level data structures. // High level access to low level data structures.
......
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
// Package binary implements translation between // Package binary implements translation between
// unsigned integer values and byte sequences // unsigned integer values and byte sequences
// and the reading and writing of fixed-size values. // and the reading and writing of fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
package binary package binary
import ( import (
...@@ -119,9 +122,6 @@ func (bigEndian) GoString() string { return "binary.BigEndian" } ...@@ -119,9 +122,6 @@ func (bigEndian) GoString() string { return "binary.BigEndian" }
// Read reads structured binary data from r into data. // Read reads structured binary data from r into data.
// Data must be a pointer to a fixed-size value or a slice // Data must be a pointer to a fixed-size value or a slice
// of fixed-size values. // of fixed-size values.
// A fixed-size value is either a fixed-size arithmetic
// type (int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
// Bytes read from r are decoded using the specified byte order // Bytes read from r are decoded using the specified byte order
// and written to successive fields of the data. // and written to successive fields of the data.
func Read(r io.Reader, order ByteOrder, data interface{}) error { func Read(r io.Reader, order ByteOrder, data interface{}) error {
...@@ -176,11 +176,8 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error { ...@@ -176,11 +176,8 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error {
} }
// Write writes the binary representation of data into w. // Write writes the binary representation of data into w.
// Data must be a fixed-size value or a pointer to // Data must be a fixed-size value or a slice of fixed-size
// a fixed-size value. // values, or a pointer to such data.
// A fixed-size value is either a fixed-size arithmetic
// type (int8, uint8, int16, float32, complex64, ...)
// or an array or struct containing only fixed-size values.
// Bytes written to w are encoded using the specified byte order // Bytes written to w are encoded using the specified byte order
// and read from successive fields of the data. // and read from successive fields of the data.
func Write(w io.Writer, order ByteOrder, data interface{}) error { func Write(w io.Writer, order ByteOrder, data interface{}) error {
...@@ -253,6 +250,12 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error { ...@@ -253,6 +250,12 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error {
return err return err
} }
// Size returns how many bytes Write would generate to encode the value v, which
// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data.
func Size(v interface{}) int {
return dataSize(reflect.Indirect(reflect.ValueOf(v)))
}
// dataSize returns the number of bytes the actual data represented by v occupies in memory. // dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice // For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory // it returns the length of the slice times the element size and does not count the memory
...@@ -373,6 +376,7 @@ func (d *decoder) value(v reflect.Value) { ...@@ -373,6 +376,7 @@ func (d *decoder) value(v reflect.Value) {
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
d.value(v.Index(i)) d.value(v.Index(i))
} }
case reflect.Struct: case reflect.Struct:
l := v.NumField() l := v.NumField()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
...@@ -428,11 +432,13 @@ func (e *encoder) value(v reflect.Value) { ...@@ -428,11 +432,13 @@ func (e *encoder) value(v reflect.Value) {
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
e.value(v.Index(i)) e.value(v.Index(i))
} }
case reflect.Struct: case reflect.Struct:
l := v.NumField() l := v.NumField()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
e.value(v.Field(i)) e.value(v.Field(i))
} }
case reflect.Slice: case reflect.Slice:
l := v.Len() l := v.Len()
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
......
...@@ -456,7 +456,7 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { ...@@ -456,7 +456,7 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
} }
if *(*unsafe.Pointer)(up) == nil { if *(*unsafe.Pointer)(up) == nil {
// Allocate object. // Allocate object.
*(*unsafe.Pointer)(up) = unsafe.New(rtyp) *(*unsafe.Pointer)(up) = unsafe.Pointer(reflect.New(rtyp).Pointer())
} }
return *(*uintptr)(up) return *(*uintptr)(up)
} }
...@@ -609,7 +609,7 @@ func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr, ...@@ -609,7 +609,7 @@ func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, p uintptr,
// Maps cannot be accessed by moving addresses around the way // Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for // that slices etc. can. We must recover a full reflection value for
// the iteration. // the iteration.
v := reflect.ValueOf(unsafe.Unreflect(mtyp, unsafe.Pointer(p))) v := reflect.NewAt(mtyp, unsafe.Pointer(p)).Elem()
n := int(state.decodeUint()) n := int(state.decodeUint())
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
key := decodeIntoValue(state, keyOp, keyIndir, allocValue(mtyp.Key()), ovfl) key := decodeIntoValue(state, keyOp, keyIndir, allocValue(mtyp.Key()), ovfl)
...@@ -662,7 +662,7 @@ func (dec *Decoder) decodeSlice(atyp reflect.Type, state *decoderState, p uintpt ...@@ -662,7 +662,7 @@ func (dec *Decoder) decodeSlice(atyp reflect.Type, state *decoderState, p uintpt
// Always write a header at p. // Always write a header at p.
hdrp := (*reflect.SliceHeader)(unsafe.Pointer(p)) hdrp := (*reflect.SliceHeader)(unsafe.Pointer(p))
if hdrp.Cap < n { if hdrp.Cap < n {
hdrp.Data = uintptr(unsafe.NewArray(atyp.Elem(), n)) hdrp.Data = reflect.MakeSlice(atyp, n, n).Pointer()
hdrp.Cap = n hdrp.Cap = n
} }
hdrp.Len = n hdrp.Len = n
...@@ -969,16 +969,16 @@ func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { ...@@ -969,16 +969,16 @@ func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) {
// Caller has gotten us to within one indirection of our value. // Caller has gotten us to within one indirection of our value.
if i.indir > 0 { if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil { if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.New(ut.base) *(*unsafe.Pointer)(p) = unsafe.Pointer(reflect.New(ut.base).Pointer())
} }
} }
// Now p is a pointer to the base type. Do we need to climb out to // Now p is a pointer to the base type. Do we need to climb out to
// get to the receiver type? // get to the receiver type?
var v reflect.Value var v reflect.Value
if ut.decIndir == -1 { if ut.decIndir == -1 {
v = reflect.ValueOf(unsafe.Unreflect(rcvrType, unsafe.Pointer(&p))) v = reflect.NewAt(rcvrType, unsafe.Pointer(&p)).Elem()
} else { } else {
v = reflect.ValueOf(unsafe.Unreflect(rcvrType, p)) v = reflect.NewAt(rcvrType, p).Elem()
} }
state.dec.decodeGobDecoder(state, v) state.dec.decodeGobDecoder(state, v)
} }
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build ignore
package main package main
// Need to compile package gob with debug.go to build this program. // Need to compile package gob with debug.go to build this program.
......
...@@ -590,7 +590,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp ...@@ -590,7 +590,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
// Maps cannot be accessed by moving addresses around the way // Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for // that slices etc. can. We must recover a full reflection value for
// the iteration. // the iteration.
v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) v := reflect.NewAt(t, unsafe.Pointer(p)).Elem()
mv := reflect.Indirect(v) mv := reflect.Indirect(v)
// We send zero-length (but non-nil) maps because the // We send zero-length (but non-nil) maps because the
// receiver might want to use the map. (Maps don't use append.) // receiver might want to use the map. (Maps don't use append.)
...@@ -613,7 +613,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp ...@@ -613,7 +613,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Interfaces transmit the name and contents of the concrete // Interfaces transmit the name and contents of the concrete
// value they contain. // value they contain.
v := reflect.ValueOf(unsafe.Unreflect(t, unsafe.Pointer(p))) v := reflect.NewAt(t, unsafe.Pointer(p)).Elem()
iv := reflect.Indirect(v) iv := reflect.Indirect(v)
if !state.sendZero && (!iv.IsValid() || iv.IsNil()) { if !state.sendZero && (!iv.IsValid() || iv.IsNil()) {
return return
...@@ -645,9 +645,9 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { ...@@ -645,9 +645,9 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) {
var v reflect.Value var v reflect.Value
if ut.encIndir == -1 { if ut.encIndir == -1 {
// Need to climb up one level to turn value into pointer. // Need to climb up one level to turn value into pointer.
v = reflect.ValueOf(unsafe.Unreflect(rt, unsafe.Pointer(&p))) v = reflect.NewAt(rt, unsafe.Pointer(&p)).Elem()
} else { } else {
v = reflect.ValueOf(unsafe.Unreflect(rt, p)) v = reflect.NewAt(rt, p).Elem()
} }
if !state.sendZero && isZero(v) { if !state.sendZero && isZero(v) {
return return
......
...@@ -87,7 +87,7 @@ func TestInvalidErr(t *testing.T) { ...@@ -87,7 +87,7 @@ func TestInvalidErr(t *testing.T) {
dst := make([]byte, DecodedLen(len(test.in))) dst := make([]byte, DecodedLen(len(test.in)))
_, err := Decode(dst, []byte(test.in)) _, err := Decode(dst, []byte(test.in))
if err == nil { if err == nil {
t.Errorf("#%d: expected error; got none") t.Errorf("#%d: expected error; got none", i)
} else if err.Error() != test.err { } else if err.Error() != test.err {
t.Errorf("#%d: got: %v want: %v", i, err, test.err) t.Errorf("#%d: got: %v want: %v", i, err, test.err)
} }
...@@ -98,7 +98,7 @@ func TestInvalidStringErr(t *testing.T) { ...@@ -98,7 +98,7 @@ func TestInvalidStringErr(t *testing.T) {
for i, test := range errTests { for i, test := range errTests {
_, err := DecodeString(test.in) _, err := DecodeString(test.in)
if err == nil { if err == nil {
t.Errorf("#%d: expected error; got none") t.Errorf("#%d: expected error; got none", i)
} else if err.Error() != test.err { } else if err.Error() != test.err {
t.Errorf("#%d: got: %v want: %v", i, err, test.err) t.Errorf("#%d: got: %v want: %v", i, err, test.err)
} }
......
...@@ -19,6 +19,9 @@ type Decoder struct { ...@@ -19,6 +19,9 @@ type Decoder struct {
} }
// NewDecoder returns a new decoder that reads from r. // NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder { func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r} return &Decoder{r: r}
} }
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
package xml package xml
import "time"
var atomValue = &Feed{ var atomValue = &Feed{
XMLName: Name{"http://www.w3.org/2005/Atom", "feed"}, XMLName: Name{"http://www.w3.org/2005/Atom", "feed"},
Title: "Example Feed", Title: "Example Feed",
...@@ -24,11 +26,10 @@ var atomValue = &Feed{ ...@@ -24,11 +26,10 @@ var atomValue = &Feed{
} }
var atomXml = `` + var atomXml = `` +
`<feed xmlns="http://www.w3.org/2005/Atom">` + `<feed xmlns="http://www.w3.org/2005/Atom" updated="2003-12-13T18:30:02Z">` +
`<title>Example Feed</title>` + `<title>Example Feed</title>` +
`<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` + `<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` +
`<link href="http://example.org/"></link>` + `<link href="http://example.org/"></link>` +
`<updated>2003-12-13T18:30:02Z</updated>` +
`<author><name>John Doe</name><uri></uri><email></email></author>` + `<author><name>John Doe</name><uri></uri><email></email></author>` +
`<entry>` + `<entry>` +
`<title>Atom-Powered Robots Run Amok</title>` + `<title>Atom-Powered Robots Run Amok</title>` +
...@@ -40,8 +41,12 @@ var atomXml = `` + ...@@ -40,8 +41,12 @@ var atomXml = `` +
`</entry>` + `</entry>` +
`</feed>` `</feed>`
func ParseTime(str string) Time { func ParseTime(str string) time.Time {
return Time(str) t, err := time.Parse(time.RFC3339, str)
if err != nil {
panic(err)
}
return t
} }
func NewText(text string) Text { func NewText(text string) Text {
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
) )
const ( const (
...@@ -52,6 +53,10 @@ const ( ...@@ -52,6 +53,10 @@ const (
// - a field with tag ",comment" is written as an XML comment, not // - a field with tag ",comment" is written as an XML comment, not
// subject to the usual marshalling procedure. It must not contain // subject to the usual marshalling procedure. It must not contain
// the "--" string within it. // the "--" string within it.
// - a field with a tag including the "omitempty" option is omitted
// if the field value is empty. The empty values are false, 0, any
// nil pointer or interface value, and any array, slice, map, or
// string of length zero.
// //
// If a field uses a tag "a>b>c", then the element c will be nested inside // If a field uses a tag "a>b>c", then the element c will be nested inside
// parent elements a and b. Fields that appear next to each other that name // parent elements a and b. Fields that appear next to each other that name
...@@ -63,6 +68,8 @@ const ( ...@@ -63,6 +68,8 @@ const (
// FirstName string `xml:"person>name>first"` // FirstName string `xml:"person>name>first"`
// LastName string `xml:"person>name>last"` // LastName string `xml:"person>name>last"`
// Age int `xml:"person>age"` // Age int `xml:"person>age"`
// Height float `xml:"person>height,omitempty"`
// Married bool `xml:"person>married"`
// } // }
// //
// xml.Marshal(&Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42}) // xml.Marshal(&Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42})
...@@ -76,6 +83,7 @@ const ( ...@@ -76,6 +83,7 @@ const (
// <last>Doe</last> // <last>Doe</last>
// </name> // </name>
// <age>42</age> // <age>42</age>
// <married>false</married>
// </person> // </person>
// </result> // </result>
// //
...@@ -116,6 +124,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { ...@@ -116,6 +124,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
if !val.IsValid() { if !val.IsValid() {
return nil return nil
} }
if finfo != nil && finfo.flags&fOmitEmpty != 0 && isEmptyValue(val) {
return nil
}
kind := val.Kind() kind := val.Kind()
typ := val.Type() typ := val.Type()
...@@ -183,12 +194,8 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { ...@@ -183,12 +194,8 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
continue continue
} }
fv := val.FieldByIndex(finfo.idx) fv := val.FieldByIndex(finfo.idx)
switch fv.Kind() { if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) {
case reflect.String, reflect.Array, reflect.Slice: continue
// TODO: Should we really do this once ,omitempty is in?
if fv.Len() == 0 {
continue
}
} }
p.WriteByte(' ') p.WriteByte(' ')
p.WriteString(finfo.name) p.WriteString(finfo.name)
...@@ -217,7 +224,14 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { ...@@ -217,7 +224,14 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
return nil return nil
} }
var timeType = reflect.TypeOf(time.Time{})
func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error { func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error {
// Normally we don't see structs, but this can happen for an attribute.
if val.Type() == timeType {
p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
return nil
}
switch val.Kind() { switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
p.WriteString(strconv.FormatInt(val.Int(), 10)) p.WriteString(strconv.FormatInt(val.Int(), 10))
...@@ -249,6 +263,10 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error { ...@@ -249,6 +263,10 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error {
var ddBytes = []byte("--") var ddBytes = []byte("--")
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
if val.Type() == timeType {
p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
return nil
}
s := parentStack{printer: p} s := parentStack{printer: p}
for i := range tinfo.fields { for i := range tinfo.fields {
finfo := &tinfo.fields[i] finfo := &tinfo.fields[i]
...@@ -378,3 +396,21 @@ type UnsupportedTypeError struct { ...@@ -378,3 +396,21 @@ type UnsupportedTypeError struct {
func (e *UnsupportedTypeError) Error() string { func (e *UnsupportedTypeError) Error() string {
return "xml: unsupported type: " + e.Type.String() return "xml: unsupported type: " + e.Type.String()
} }
func isEmptyValue(v reflect.Value) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
}
return false
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
) )
type DriveType int type DriveType int
...@@ -38,14 +39,14 @@ type NamedType string ...@@ -38,14 +39,14 @@ type NamedType string
type Port struct { type Port struct {
XMLName struct{} `xml:"port"` XMLName struct{} `xml:"port"`
Type string `xml:"type,attr"` Type string `xml:"type,attr,omitempty"`
Comment string `xml:",comment"` Comment string `xml:",comment"`
Number string `xml:",chardata"` Number string `xml:",chardata"`
} }
type Domain struct { type Domain struct {
XMLName struct{} `xml:"domain"` XMLName struct{} `xml:"domain"`
Country string `xml:",attr"` Country string `xml:",attr,omitempty"`
Name []byte `xml:",chardata"` Name []byte `xml:",chardata"`
Comment []byte `xml:",comment"` Comment []byte `xml:",comment"`
} }
...@@ -149,11 +150,33 @@ type NameInField struct { ...@@ -149,11 +150,33 @@ type NameInField struct {
type AttrTest struct { type AttrTest struct {
Int int `xml:",attr"` Int int `xml:",attr"`
Lower int `xml:"int,attr"` Named int `xml:"int,attr"`
Float float64 `xml:",attr"` Float float64 `xml:",attr"`
Uint8 uint8 `xml:",attr"` Uint8 uint8 `xml:",attr"`
Bool bool `xml:",attr"` Bool bool `xml:",attr"`
Str string `xml:",attr"` Str string `xml:",attr"`
Bytes []byte `xml:",attr"`
}
type OmitAttrTest struct {
Int int `xml:",attr,omitempty"`
Named int `xml:"int,attr,omitempty"`
Float float64 `xml:",attr,omitempty"`
Uint8 uint8 `xml:",attr,omitempty"`
Bool bool `xml:",attr,omitempty"`
Str string `xml:",attr,omitempty"`
Bytes []byte `xml:",attr,omitempty"`
}
type OmitFieldTest struct {
Int int `xml:",omitempty"`
Named int `xml:"int,omitempty"`
Float float64 `xml:",omitempty"`
Uint8 uint8 `xml:",omitempty"`
Bool bool `xml:",omitempty"`
Str string `xml:",omitempty"`
Bytes []byte `xml:",omitempty"`
Ptr *PresenceTest `xml:",omitempty"`
} }
type AnyTest struct { type AnyTest struct {
...@@ -234,6 +257,12 @@ var marshalTests = []struct { ...@@ -234,6 +257,12 @@ var marshalTests = []struct {
{Value: &Plain{[]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`}, {Value: &Plain{[]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`},
{Value: &Plain{[3]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`}, {Value: &Plain{[3]int{1, 2, 3}}, ExpectXML: `<Plain><V>1</V><V>2</V><V>3</V></Plain>`},
// Test time.
{
Value: &Plain{time.Unix(1e9, 123456789).UTC()},
ExpectXML: `<Plain><V>2001-09-09T01:46:40.123456789Z</V></Plain>`,
},
// A pointer to struct{} may be used to test for an element's presence. // A pointer to struct{} may be used to test for an element's presence.
{ {
Value: &PresenceTest{new(struct{})}, Value: &PresenceTest{new(struct{})},
...@@ -549,13 +578,65 @@ var marshalTests = []struct { ...@@ -549,13 +578,65 @@ var marshalTests = []struct {
{ {
Value: &AttrTest{ Value: &AttrTest{
Int: 8, Int: 8,
Lower: 9, Named: 9,
Float: 23.5,
Uint8: 255,
Bool: true,
Str: "str",
Bytes: []byte("byt"),
},
ExpectXML: `<AttrTest Int="8" int="9" Float="23.5" Uint8="255"` +
` Bool="true" Str="str" Bytes="byt"></AttrTest>`,
},
{
Value: &AttrTest{Bytes: []byte{}},
ExpectXML: `<AttrTest Int="0" int="0" Float="0" Uint8="0"` +
` Bool="false" Str="" Bytes=""></AttrTest>`,
},
{
Value: &OmitAttrTest{
Int: 8,
Named: 9,
Float: 23.5,
Uint8: 255,
Bool: true,
Str: "str",
Bytes: []byte("byt"),
},
ExpectXML: `<OmitAttrTest Int="8" int="9" Float="23.5" Uint8="255"` +
` Bool="true" Str="str" Bytes="byt"></OmitAttrTest>`,
},
{
Value: &OmitAttrTest{},
ExpectXML: `<OmitAttrTest></OmitAttrTest>`,
},
// omitempty on fields
{
Value: &OmitFieldTest{
Int: 8,
Named: 9,
Float: 23.5, Float: 23.5,
Uint8: 255, Uint8: 255,
Bool: true, Bool: true,
Str: "s", Str: "str",
Bytes: []byte("byt"),
Ptr: &PresenceTest{},
}, },
ExpectXML: `<AttrTest Int="8" int="9" Float="23.5" Uint8="255" Bool="true" Str="s"></AttrTest>`, ExpectXML: `<OmitFieldTest>` +
`<Int>8</Int>` +
`<int>9</int>` +
`<Float>23.5</Float>` +
`<Uint8>255</Uint8>` +
`<Bool>true</Bool>` +
`<Str>str</Str>` +
`<Bytes>byt</Bytes>` +
`<Ptr></Ptr>` +
`</OmitFieldTest>`,
},
{
Value: &OmitFieldTest{},
ExpectXML: `<OmitFieldTest></OmitFieldTest>`,
}, },
// Test ",any" // Test ",any"
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
) )
// BUG(rsc): Mapping between XML elements and data structures is inherently flawed: // BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
...@@ -270,6 +271,10 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { ...@@ -270,6 +271,10 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
v.Set(reflect.ValueOf(start.Name)) v.Set(reflect.ValueOf(start.Name))
break break
} }
if typ == timeType {
saveData = v
break
}
sv = v sv = v
tinfo, err = getTypeInfo(typ) tinfo, err = getTypeInfo(typ)
...@@ -473,6 +478,14 @@ func copyValue(dst reflect.Value, src []byte) (err error) { ...@@ -473,6 +478,14 @@ func copyValue(dst reflect.Value, src []byte) (err error) {
src = []byte{} src = []byte{}
} }
t.SetBytes(src) t.SetBytes(src)
case reflect.Struct:
if t.Type() == timeType {
tv, err := time.Parse(time.RFC3339, string(src))
if err != nil {
return err
}
t.Set(reflect.ValueOf(tv))
}
} }
return nil return nil
} }
......
...@@ -7,6 +7,7 @@ package xml ...@@ -7,6 +7,7 @@ package xml
import ( import (
"reflect" "reflect"
"testing" "testing"
"time"
) )
// Stripped down Atom feed data structures. // Stripped down Atom feed data structures.
...@@ -24,7 +25,7 @@ func TestUnmarshalFeed(t *testing.T) { ...@@ -24,7 +25,7 @@ func TestUnmarshalFeed(t *testing.T) {
// hget http://codereview.appspot.com/rss/mine/rsc // hget http://codereview.appspot.com/rss/mine/rsc
const atomFeedString = ` const atomFeedString = `
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><link href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></link><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub <feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us" updated="2009-10-04T01:35:58+00:00"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><link href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></link><id>http://codereview.appspot.com/</id><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub
</title><link href="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html"> </title><link href="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html">
An attempt at adding pubsubhubbub support to Rietveld. An attempt at adding pubsubhubbub support to Rietveld.
http://code.google.com/p/pubsubhubbub http://code.google.com/p/pubsubhubbub
...@@ -78,26 +79,26 @@ not being used from outside intra_region_diff.py. ...@@ -78,26 +79,26 @@ not being used from outside intra_region_diff.py.
</summary></entry></feed> ` </summary></entry></feed> `
type Feed struct { type Feed struct {
XMLName Name `xml:"http://www.w3.org/2005/Atom feed"` XMLName Name `xml:"http://www.w3.org/2005/Atom feed"`
Title string `xml:"title"` Title string `xml:"title"`
Id string `xml:"id"` Id string `xml:"id"`
Link []Link `xml:"link"` Link []Link `xml:"link"`
Updated Time `xml:"updated"` Updated time.Time `xml:"updated,attr"`
Author Person `xml:"author"` Author Person `xml:"author"`
Entry []Entry `xml:"entry"` Entry []Entry `xml:"entry"`
} }
type Entry struct { type Entry struct {
Title string `xml:"title"` Title string `xml:"title"`
Id string `xml:"id"` Id string `xml:"id"`
Link []Link `xml:"link"` Link []Link `xml:"link"`
Updated Time `xml:"updated"` Updated time.Time `xml:"updated"`
Author Person `xml:"author"` Author Person `xml:"author"`
Summary Text `xml:"summary"` Summary Text `xml:"summary"`
} }
type Link struct { type Link struct {
Rel string `xml:"rel,attr"` Rel string `xml:"rel,attr,omitempty"`
Href string `xml:"href,attr"` Href string `xml:"href,attr"`
} }
...@@ -109,12 +110,10 @@ type Person struct { ...@@ -109,12 +110,10 @@ type Person struct {
} }
type Text struct { type Text struct {
Type string `xml:"type,attr"` Type string `xml:"type,attr,omitempty"`
Body string `xml:",chardata"` Body string `xml:",chardata"`
} }
type Time string
var atomFeed = Feed{ var atomFeed = Feed{
XMLName: Name{"http://www.w3.org/2005/Atom", "feed"}, XMLName: Name{"http://www.w3.org/2005/Atom", "feed"},
Title: "Code Review - My issues", Title: "Code Review - My issues",
...@@ -123,7 +122,7 @@ var atomFeed = Feed{ ...@@ -123,7 +122,7 @@ var atomFeed = Feed{
{Rel: "self", Href: "http://codereview.appspot.com/rss/mine/rsc"}, {Rel: "self", Href: "http://codereview.appspot.com/rss/mine/rsc"},
}, },
Id: "http://codereview.appspot.com/", Id: "http://codereview.appspot.com/",
Updated: "2009-10-04T01:35:58+00:00", Updated: ParseTime("2009-10-04T01:35:58+00:00"),
Author: Person{ Author: Person{
Name: "rietveld<>", Name: "rietveld<>",
InnerXML: "<name>rietveld&lt;&gt;</name>", InnerXML: "<name>rietveld&lt;&gt;</name>",
...@@ -134,7 +133,7 @@ var atomFeed = Feed{ ...@@ -134,7 +133,7 @@ var atomFeed = Feed{
Link: []Link{ Link: []Link{
{Rel: "alternate", Href: "http://codereview.appspot.com/126085"}, {Rel: "alternate", Href: "http://codereview.appspot.com/126085"},
}, },
Updated: "2009-10-04T01:35:58+00:00", Updated: ParseTime("2009-10-04T01:35:58+00:00"),
Author: Person{ Author: Person{
Name: "email-address-removed", Name: "email-address-removed",
InnerXML: "<name>email-address-removed</name>", InnerXML: "<name>email-address-removed</name>",
...@@ -181,7 +180,7 @@ the top of feeds.py marked NOTE(rsc). ...@@ -181,7 +180,7 @@ the top of feeds.py marked NOTE(rsc).
Link: []Link{ Link: []Link{
{Rel: "alternate", Href: "http://codereview.appspot.com/124106"}, {Rel: "alternate", Href: "http://codereview.appspot.com/124106"},
}, },
Updated: "2009-10-03T23:02:17+00:00", Updated: ParseTime("2009-10-03T23:02:17+00:00"),
Author: Person{ Author: Person{
Name: "email-address-removed", Name: "email-address-removed",
InnerXML: "<name>email-address-removed</name>", InnerXML: "<name>email-address-removed</name>",
......
...@@ -36,8 +36,7 @@ const ( ...@@ -36,8 +36,7 @@ const (
fComment fComment
fAny fAny
// TODO: fOmitEmpty
//fOmitEmpty
fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny
) )
...@@ -133,20 +132,28 @@ func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, erro ...@@ -133,20 +132,28 @@ func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, erro
finfo.flags |= fComment finfo.flags |= fComment
case "any": case "any":
finfo.flags |= fAny finfo.flags |= fAny
case "omitempty":
finfo.flags |= fOmitEmpty
} }
} }
// Validate the flags used. // Validate the flags used.
valid := true
switch mode := finfo.flags & fMode; mode { switch mode := finfo.flags & fMode; mode {
case 0: case 0:
finfo.flags |= fElement finfo.flags |= fElement
case fAttr, fCharData, fInnerXml, fComment, fAny: case fAttr, fCharData, fInnerXml, fComment, fAny:
if f.Name != "XMLName" && (tag == "" || mode == fAttr) { if f.Name == "XMLName" || tag != "" && mode != fAttr {
break valid = false
} }
fallthrough
default: default:
// This will also catch multiple modes in a single field. // This will also catch multiple modes in a single field.
valid = false
}
if finfo.flags&fOmitEmpty != 0 && finfo.flags&(fElement|fAttr) == 0 {
valid = false
}
if !valid {
return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q", return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml")) f.Name, typ, f.Tag.Get("xml"))
} }
......
...@@ -102,7 +102,7 @@ func (rb *reorderBuffer) insert(src input, i int, info runeInfo) bool { ...@@ -102,7 +102,7 @@ func (rb *reorderBuffer) insert(src input, i int, info runeInfo) bool {
} }
} }
if info.hasDecomposition() { if info.hasDecomposition() {
dcomp := rb.f.decompose(src, i) dcomp := info.decomposition()
rb.tmpBytes = inputBytes(dcomp) rb.tmpBytes = inputBytes(dcomp)
for i := 0; i < len(dcomp); { for i := 0; i < len(dcomp); {
info = rb.f.info(&rb.tmpBytes, i) info = rb.f.info(&rb.tmpBytes, i)
......
...@@ -6,25 +6,50 @@ package norm ...@@ -6,25 +6,50 @@ package norm
// This file contains Form-specific logic and wrappers for data in tables.go. // This file contains Form-specific logic and wrappers for data in tables.go.
// Rune info is stored in a separate trie per composing form. A composing form
// and its corresponding decomposing form share the same trie. Each trie maps
// a rune to a uint16. The values take two forms. For v >= 0x8000:
// bits
// 0..8: ccc
// 9..12: qcInfo (see below). isYesD is always true (no decompostion).
// 16: 1
// For v < 0x8000, the respective rune has a decomposition and v is an index
// into a byte array of UTF-8 decomposition sequences and additional info and
// has the form:
// <header> <decomp_byte>* [<tccc> [<lccc>]]
// The header contains the number of bytes in the decomposition (excluding this
// length byte). The two most significant bits of this lenght byte correspond
// to bit 2 and 3 of qcIfo (see below). The byte sequence itself starts at v+1.
// The byte sequence is followed by a trailing and leading CCC if the values
// for these are not zero. The value of v determines which ccc are appended
// to the sequences. For v < firstCCC, there are none, for v >= firstCCC,
// the seqence is followed by a trailing ccc, and for v >= firstLeadingCC
// there is an additional leading ccc.
const (
qcInfoMask = 0xF // to clear all but the relevant bits in a qcInfo
headerLenMask = 0x3F // extract the lenght value from the header byte
headerFlagsMask = 0xC0 // extract the qcInfo bits from the header byte
)
// runeInfo is a representation for the data stored in charinfoTrie.
type runeInfo struct { type runeInfo struct {
pos uint8 // start position in reorderBuffer; used in composition.go pos uint8 // start position in reorderBuffer; used in composition.go
size uint8 // length of UTF-8 encoding of this rune size uint8 // length of UTF-8 encoding of this rune
ccc uint8 // canonical combining class ccc uint8 // leading canonical combining class (ccc if not decomposition)
tccc uint8 // trailing canonical combining class (ccc if not decomposition)
flags qcInfo // quick check flags flags qcInfo // quick check flags
index uint16
} }
// functions dispatchable per form // functions dispatchable per form
type lookupFunc func(b input, i int) runeInfo type lookupFunc func(b input, i int) runeInfo
type decompFunc func(b input, i int) []byte
// formInfo holds Form-specific functions and tables. // formInfo holds Form-specific functions and tables.
type formInfo struct { type formInfo struct {
form Form form Form
composing, compatibility bool // form type composing, compatibility bool // form type
info lookupFunc
decompose decompFunc
info lookupFunc
} }
var formTable []*formInfo var formTable []*formInfo
...@@ -38,10 +63,8 @@ func init() { ...@@ -38,10 +63,8 @@ func init() {
f.form = Form(i) f.form = Form(i)
if Form(i) == NFKD || Form(i) == NFKC { if Form(i) == NFKD || Form(i) == NFKC {
f.compatibility = true f.compatibility = true
f.decompose = decomposeNFKC
f.info = lookupInfoNFKC f.info = lookupInfoNFKC
} else { } else {
f.decompose = decomposeNFC
f.info = lookupInfoNFC f.info = lookupInfoNFC
} }
if Form(i) == NFC || Form(i) == NFKC { if Form(i) == NFC || Form(i) == NFKC {
...@@ -76,8 +99,6 @@ func (i runeInfo) boundaryAfter() bool { ...@@ -76,8 +99,6 @@ func (i runeInfo) boundaryAfter() bool {
// //
// When all 4 bits are zero, the character is inert, meaning it is never // When all 4 bits are zero, the character is inert, meaning it is never
// influenced by normalization. // influenced by normalization.
//
// We pack the bits for both NFC/D and NFKC/D in one byte.
type qcInfo uint8 type qcInfo uint8
func (i runeInfo) isYesC() bool { return i.flags&0x4 == 0 } func (i runeInfo) isYesC() bool { return i.flags&0x4 == 0 }
...@@ -91,22 +112,12 @@ func (r runeInfo) isInert() bool { ...@@ -91,22 +112,12 @@ func (r runeInfo) isInert() bool {
return r.flags&0xf == 0 && r.ccc == 0 return r.flags&0xf == 0 && r.ccc == 0
} }
// Wrappers for tables.go func (r runeInfo) decomposition() []byte {
if r.index == 0 {
// The 16-bit value of the decomposition tries is an index into a byte return nil
// array of UTF-8 decomposition sequences. The first byte is the number }
// of bytes in the decomposition (excluding this length byte). The actual p := r.index
// sequence starts at the offset+1. n := decomps[p] & 0x3F
func decomposeNFC(s input, i int) []byte {
p := s.decomposeNFC(i)
n := decomps[p]
p++
return decomps[p : p+uint16(n)]
}
func decomposeNFKC(s input, i int) []byte {
p := s.decomposeNFKC(i)
n := decomps[p]
p++ p++
return decomps[p : p+uint16(n)] return decomps[p : p+uint16(n)]
} }
...@@ -124,16 +135,40 @@ func combine(a, b rune) rune { ...@@ -124,16 +135,40 @@ func combine(a, b rune) rune {
return recompMap[key] return recompMap[key]
} }
// The 16-bit character info has the following bit layout:
// 0..7 CCC value.
// 8..11 qcInfo for NFC/NFD
// 12..15 qcInfo for NFKC/NFKD
func lookupInfoNFC(b input, i int) runeInfo { func lookupInfoNFC(b input, i int) runeInfo {
v, sz := b.charinfo(i) v, sz := b.charinfoNFC(i)
return runeInfo{size: uint8(sz), ccc: uint8(v), flags: qcInfo(v >> 8)} return compInfo(v, sz)
} }
func lookupInfoNFKC(b input, i int) runeInfo { func lookupInfoNFKC(b input, i int) runeInfo {
v, sz := b.charinfo(i) v, sz := b.charinfoNFKC(i)
return runeInfo{size: uint8(sz), ccc: uint8(v), flags: qcInfo(v >> 12)} return compInfo(v, sz)
}
// compInfo converts the information contained in v and sz
// to a runeInfo. See the comment at the top of the file
// for more information on the format.
func compInfo(v uint16, sz int) runeInfo {
if v == 0 {
return runeInfo{size: uint8(sz)}
} else if v >= 0x8000 {
return runeInfo{
size: uint8(sz),
ccc: uint8(v),
tccc: uint8(v),
flags: qcInfo(v>>8) & qcInfoMask,
}
}
// has decomposition
h := decomps[v]
f := (qcInfo(h&headerFlagsMask) >> 4) | 0x1
ri := runeInfo{size: uint8(sz), flags: f, index: v}
if v >= firstCCC {
v += uint16(h&headerLenMask) + 1
ri.tccc = decomps[v]
if v >= firstLeadingCCC {
ri.ccc = decomps[v+1]
}
}
return ri
} }
...@@ -11,9 +11,8 @@ type input interface { ...@@ -11,9 +11,8 @@ type input interface {
skipNonStarter(p int) int skipNonStarter(p int) int
appendSlice(buf []byte, s, e int) []byte appendSlice(buf []byte, s, e int) []byte
copySlice(buf []byte, s, e int) copySlice(buf []byte, s, e int)
charinfo(p int) (uint16, int) charinfoNFC(p int) (uint16, int)
decomposeNFC(p int) uint16 charinfoNFKC(p int) (uint16, int)
decomposeNFKC(p int) uint16
hangul(p int) rune hangul(p int) rune
} }
...@@ -42,16 +41,12 @@ func (s inputString) copySlice(buf []byte, b, e int) { ...@@ -42,16 +41,12 @@ func (s inputString) copySlice(buf []byte, b, e int) {
copy(buf, s[b:e]) copy(buf, s[b:e])
} }
func (s inputString) charinfo(p int) (uint16, int) { func (s inputString) charinfoNFC(p int) (uint16, int) {
return charInfoTrie.lookupString(string(s[p:])) return nfcTrie.lookupString(string(s[p:]))
} }
func (s inputString) decomposeNFC(p int) uint16 { func (s inputString) charinfoNFKC(p int) (uint16, int) {
return nfcDecompTrie.lookupStringUnsafe(string(s[p:])) return nfkcTrie.lookupString(string(s[p:]))
}
func (s inputString) decomposeNFKC(p int) uint16 {
return nfkcDecompTrie.lookupStringUnsafe(string(s[p:]))
} }
func (s inputString) hangul(p int) rune { func (s inputString) hangul(p int) rune {
...@@ -84,16 +79,12 @@ func (s inputBytes) copySlice(buf []byte, b, e int) { ...@@ -84,16 +79,12 @@ func (s inputBytes) copySlice(buf []byte, b, e int) {
copy(buf, s[b:e]) copy(buf, s[b:e])
} }
func (s inputBytes) charinfo(p int) (uint16, int) { func (s inputBytes) charinfoNFC(p int) (uint16, int) {
return charInfoTrie.lookup(s[p:]) return nfcTrie.lookup(s[p:])
}
func (s inputBytes) decomposeNFC(p int) uint16 {
return nfcDecompTrie.lookupUnsafe(s[p:])
} }
func (s inputBytes) decomposeNFKC(p int) uint16 { func (s inputBytes) charinfoNFKC(p int) (uint16, int) {
return nfkcDecompTrie.lookupUnsafe(s[p:]) return nfkcTrie.lookup(s[p:])
} }
func (s inputBytes) hangul(p int) rune { func (s inputBytes) hangul(p int) rune {
......
...@@ -2,8 +2,11 @@ ...@@ -2,8 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build ignore
// Normalization table generator. // Normalization table generator.
// Data read from the web. // Data read from the web.
// See forminfo.go for a description of the trie values associated with each rune.
package main package main
...@@ -17,6 +20,7 @@ import ( ...@@ -17,6 +20,7 @@ import (
"net/http" "net/http"
"os" "os"
"regexp" "regexp"
"sort"
"strconv" "strconv"
"strings" "strings"
) )
...@@ -187,18 +191,14 @@ func (f FormInfo) String() string { ...@@ -187,18 +191,14 @@ func (f FormInfo) String() string {
fmt.Fprintf(buf, " cmbBackward: %v\n", f.combinesBackward) fmt.Fprintf(buf, " cmbBackward: %v\n", f.combinesBackward)
fmt.Fprintf(buf, " isOneWay: %v\n", f.isOneWay) fmt.Fprintf(buf, " isOneWay: %v\n", f.isOneWay)
fmt.Fprintf(buf, " inDecomp: %v\n", f.inDecomp) fmt.Fprintf(buf, " inDecomp: %v\n", f.inDecomp)
fmt.Fprintf(buf, " decomposition: %v\n", f.decomp) fmt.Fprintf(buf, " decomposition: %X\n", f.decomp)
fmt.Fprintf(buf, " expandedDecomp: %v\n", f.expandedDecomp) fmt.Fprintf(buf, " expandedDecomp: %X\n", f.expandedDecomp)
return buf.String() return buf.String()
} }
type Decomposition []rune type Decomposition []rune
func (d Decomposition) String() string {
return fmt.Sprintf("%.4X", d)
}
func openReader(file string) (input io.ReadCloser) { func openReader(file string) (input io.ReadCloser) {
if *localFiles { if *localFiles {
f, err := os.Open(file) f, err := os.Open(file)
...@@ -571,80 +571,121 @@ func makeEntry(f *FormInfo) uint16 { ...@@ -571,80 +571,121 @@ func makeEntry(f *FormInfo) uint16 {
return e return e
} }
// Bits // decompSet keeps track of unique decompositions, grouped by whether
// 0..8: CCC // the decomposition is followed by a trailing and/or leading CCC.
// 9..12: NF(C|D) qc bits. type decompSet [4]map[string]bool
// 13..16: NFK(C|D) qc bits.
func makeCharInfo(c Char) uint16 { func makeDecompSet() decompSet {
e := makeEntry(&c.forms[FCompatibility]) m := decompSet{}
e = e<<4 | makeEntry(&c.forms[FCanonical]) for i, _ := range m {
e = e<<8 | uint16(c.ccc) m[i] = make(map[string]bool)
return e }
return m
}
func (m *decompSet) insert(key int, s string) {
m[key][s] = true
} }
func printCharInfoTables() int { func printCharInfoTables() int {
// Quick Check + CCC trie. mkstr := func(r rune, f *FormInfo) (int, string) {
t := newNode() d := f.expandedDecomp
for i, char := range chars { s := string([]rune(d))
v := makeCharInfo(char) if max := 1 << 6; len(s) >= max {
if v != 0 { const msg = "%U: too many bytes in decomposition: %d >= %d"
t.insert(rune(i), v) logger.Fatalf(msg, r, len(s), max)
}
head := uint8(len(s))
if f.quickCheck[MComposed] != QCYes {
head |= 0x40
}
if f.combinesForward {
head |= 0x80
}
s = string([]byte{head}) + s
lccc := ccc(d[0])
tccc := ccc(d[len(d)-1])
if tccc < lccc && lccc != 0 {
const msg = "%U: lccc (%d) must be <= tcc (%d)"
logger.Fatalf(msg, r, lccc, tccc)
}
index := 0
if tccc > 0 || lccc > 0 {
s += string([]byte{tccc})
index = 1
if lccc > 0 {
s += string([]byte{lccc})
index |= 2
}
} }
return index, s
} }
return t.printTables("charInfo")
}
func printDecompositionTables() int { decompSet := makeDecompSet()
decompositions := bytes.NewBuffer(make([]byte, 0, 10000))
size := 0
// Map decompositions
positionMap := make(map[string]uint16)
// Store the uniqued decompositions in a byte buffer, // Store the uniqued decompositions in a byte buffer,
// preceded by their byte length. // preceded by their byte length.
for _, c := range chars { for _, c := range chars {
for f := 0; f < 2; f++ { for _, f := range c.forms {
d := c.forms[f].expandedDecomp if len(f.expandedDecomp) == 0 {
s := string([]rune(d)) continue
if _, ok := positionMap[s]; !ok {
p := decompositions.Len()
decompositions.WriteByte(uint8(len(s)))
decompositions.WriteString(s)
positionMap[s] = uint16(p)
} }
if f.combinesBackward {
logger.Fatalf("%U: combinesBackward and decompose", c.codePoint)
}
index, s := mkstr(c.codePoint, &f)
decompSet.insert(index, s)
} }
} }
decompositions := bytes.NewBuffer(make([]byte, 0, 10000))
size := 0
positionMap := make(map[string]uint16)
decompositions.WriteString("\000")
cname := []string{"firstCCC", "firstLeadingCCC", "", "lastDecomp"}
fmt.Println("const (")
for i, m := range decompSet {
sa := []string{}
for s, _ := range m {
sa = append(sa, s)
}
sort.Strings(sa)
for _, s := range sa {
p := decompositions.Len()
decompositions.WriteString(s)
positionMap[s] = uint16(p)
}
if cname[i] != "" {
fmt.Printf("%s = 0x%X\n", cname[i], decompositions.Len())
}
}
fmt.Println("maxDecomp = 0x8000")
fmt.Println(")")
b := decompositions.Bytes() b := decompositions.Bytes()
printBytes(b, "decomps") printBytes(b, "decomps")
size += len(b) size += len(b)
nfcT := newNode() varnames := []string{"nfc", "nfkc"}
nfkcT := newNode() for i := 0; i < FNumberOfFormTypes; i++ {
for i, c := range chars { trie := newNode()
d := c.forms[FCanonical].expandedDecomp for r, c := range chars {
if len(d) != 0 { f := c.forms[i]
nfcT.insert(rune(i), positionMap[string([]rune(d))]) d := f.expandedDecomp
if ccc(c.codePoint) != ccc(d[0]) { if len(d) != 0 {
// We assume the lead ccc of a decomposition is !=0 in this case. _, key := mkstr(c.codePoint, &f)
if ccc(d[0]) == 0 { trie.insert(rune(r), positionMap[key])
logger.Fatal("Expected differing CCC to be non-zero.") if c.ccc != ccc(d[0]) {
} // We assume the lead ccc of a decomposition !=0 in this case.
} if ccc(d[0]) == 0 {
} logger.Fatalf("Expected leading CCC to be non-zero; ccc is %d", c.ccc)
d = c.forms[FCompatibility].expandedDecomp }
if len(d) != 0 {
nfkcT.insert(rune(i), positionMap[string([]rune(d))])
if ccc(c.codePoint) != ccc(d[0]) {
// We assume the lead ccc of a decomposition is !=0 in this case.
if ccc(d[0]) == 0 {
logger.Fatal("Expected differing CCC to be non-zero.")
} }
} else if v := makeEntry(&f)<<8 | uint16(c.ccc); v != 0 {
trie.insert(c.codePoint, 0x8000|v)
} }
} }
size += trie.printTables(varnames[i])
} }
size += nfcT.printTables("nfcDecomp")
size += nfkcT.printTables("nfkcDecomp")
return size return size
} }
...@@ -687,15 +728,15 @@ func makeTables() { ...@@ -687,15 +728,15 @@ func makeTables() {
} }
list := strings.Split(*tablelist, ",") list := strings.Split(*tablelist, ",")
if *tablelist == "all" { if *tablelist == "all" {
list = []string{"decomp", "recomp", "info"} list = []string{"recomp", "info"}
} }
fmt.Printf(fileHeader, *tablelist, *url) fmt.Printf(fileHeader, *tablelist, *url)
fmt.Println("// Version is the Unicode edition from which the tables are derived.") fmt.Println("// Version is the Unicode edition from which the tables are derived.")
fmt.Printf("const Version = %q\n\n", version()) fmt.Printf("const Version = %q\n\n", version())
if contains(list, "decomp") { if contains(list, "info") {
size += printDecompositionTables() size += printCharInfoTables()
} }
if contains(list, "recomp") { if contains(list, "recomp") {
...@@ -730,9 +771,6 @@ func makeTables() { ...@@ -730,9 +771,6 @@ func makeTables() {
fmt.Printf("}\n\n") fmt.Printf("}\n\n")
} }
if contains(list, "info") {
size += printCharInfoTables()
}
fmt.Printf("// Total size of tables: %dKB (%d bytes)\n", (size+512)/1024, size) fmt.Printf("// Total size of tables: %dKB (%d bytes)\n", (size+512)/1024, size)
} }
...@@ -761,6 +799,11 @@ func verifyComputed() { ...@@ -761,6 +799,11 @@ func verifyComputed() {
log.Fatalf("%U: NF*C must be maybe if combinesBackward", i) log.Fatalf("%U: NF*C must be maybe if combinesBackward", i)
} }
} }
nfc := c.forms[FCanonical]
nfkc := c.forms[FCompatibility]
if nfc.combinesBackward != nfkc.combinesBackward {
logger.Fatalf("%U: Cannot combine combinesBackward\n", c.codePoint)
}
} }
} }
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build ignore
// Generate test data for trie code. // Generate test data for trie code.
package main package main
......
...@@ -448,7 +448,7 @@ func decomposeToLastBoundary(rb *reorderBuffer, buf []byte) []byte { ...@@ -448,7 +448,7 @@ func decomposeToLastBoundary(rb *reorderBuffer, buf []byte) []byte {
} }
// Check that decomposition doesn't result in overflow. // Check that decomposition doesn't result in overflow.
if info.hasDecomposition() { if info.hasDecomposition() {
dcomp := rb.f.decompose(inputBytes(buf), p-int(info.size)) dcomp := info.decomposition()
for i := 0; i < len(dcomp); { for i := 0; i < len(dcomp); {
inf := rb.f.info(inputBytes(dcomp), i) inf := rb.f.info(inputBytes(dcomp), i)
i += int(inf.size) i += int(inf.size)
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build ignore
package main package main
import ( import (
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build ignore
// Trie table generator. // Trie table generator.
// Used by make*tables tools to generate a go file with trie data structures // Used by make*tables tools to generate a go file with trie data structures
// for mapping UTF-8 to a 16-bit value. All but the last byte in a UTF-8 byte // for mapping UTF-8 to a 16-bit value. All but the last byte in a UTF-8 byte
......
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin freebsd linux netbsd openbsd
// Package signal implements operating system-independent signal handling.
package signal
import (
"os"
"runtime"
)
// Incoming is the global signal channel.
// All signals received by the program will be delivered to this channel.
var Incoming <-chan os.Signal
func process(ch chan<- os.Signal) {
for {
var mask uint32 = runtime.Sigrecv()
for sig := uint(0); sig < 32; sig++ {
if mask&(1<<sig) != 0 {
ch <- os.UnixSignal(sig)
}
}
}
}
func init() {
runtime.Siginit()
ch := make(chan os.Signal) // Done here so Incoming can have type <-chan Signal
Incoming = ch
go process(ch)
}
// BUG(rsc): This package is unavailable on Plan 9 and Windows.
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin freebsd linux netbsd openbsd
package signal
import (
"os"
"syscall"
"testing"
)
const sighup = os.UnixSignal(syscall.SIGHUP)
func TestSignal(t *testing.T) {
// Send this process a SIGHUP.
syscall.Syscall(syscall.SYS_KILL, uintptr(syscall.Getpid()), syscall.SIGHUP, 0)
if sig := (<-Incoming).(os.UnixSignal); sig != sighup {
t.Errorf("signal was %v, want %v", sig, sighup)
}
}
...@@ -17,14 +17,14 @@ import ( ...@@ -17,14 +17,14 @@ import (
const debug = false const debug = false
type checker struct { type checker struct {
fset *token.FileSet fset *token.FileSet
scanner.ErrorVector errors scanner.ErrorList
types map[ast.Expr]Type types map[ast.Expr]Type
} }
func (c *checker) errorf(pos token.Pos, format string, args ...interface{}) string { func (c *checker) errorf(pos token.Pos, format string, args ...interface{}) string {
msg := fmt.Sprintf(format, args...) msg := fmt.Sprintf(format, args...)
c.Error(c.fset.Position(pos), msg) c.errors.Add(c.fset.Position(pos), msg)
return msg return msg
} }
...@@ -221,5 +221,6 @@ func Check(fset *token.FileSet, pkg *ast.Package) (types map[ast.Expr]Type, err ...@@ -221,5 +221,6 @@ func Check(fset *token.FileSet, pkg *ast.Package) (types map[ast.Expr]Type, err
c.checkObj(obj, false) c.checkObj(obj, false)
} }
return c.types, c.GetError(scanner.NoMultiples) c.errors.RemoveMultiples()
return c.types, c.errors.Err()
} }
...@@ -11,12 +11,12 @@ import ( ...@@ -11,12 +11,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"go/ast" "go/ast"
"go/build"
"go/token" "go/token"
"io" "io"
"math/big" "math/big"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strconv" "strconv"
"text/scanner" "text/scanner"
) )
...@@ -24,7 +24,6 @@ import ( ...@@ -24,7 +24,6 @@ import (
const trace = false // set to true for debugging const trace = false // set to true for debugging
var ( var (
pkgRoot = filepath.Join(runtime.GOROOT(), "pkg", runtime.GOOS+"_"+runtime.GOARCH)
pkgExts = [...]string{".a", ".5", ".6", ".8"} pkgExts = [...]string{".a", ".5", ".6", ".8"}
) )
...@@ -39,8 +38,12 @@ func findPkg(path string) (filename, id string) { ...@@ -39,8 +38,12 @@ func findPkg(path string) (filename, id string) {
var noext string var noext string
switch path[0] { switch path[0] {
default: default:
// "x" -> "$GOROOT/pkg/$GOOS_$GOARCH/x.ext", "x" // "x" -> "$GOPATH/pkg/$GOOS_$GOARCH/x.ext", "x"
noext = filepath.Join(pkgRoot, path) tree, pkg, err := build.FindTree(path)
if err != nil {
return
}
noext = filepath.Join(tree.PkgDir(), pkg)
case '.': case '.':
// "./x" -> "/this/directory/x.ext", "/this/directory/x" // "./x" -> "/this/directory/x.ext", "/this/directory/x"
......
...@@ -6,7 +6,9 @@ package types ...@@ -6,7 +6,9 @@ package types
import ( import (
"go/ast" "go/ast"
"go/build"
"io/ioutil" "io/ioutil"
"os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime" "runtime"
...@@ -31,7 +33,7 @@ func init() { ...@@ -31,7 +33,7 @@ func init() {
gcPath = gcName gcPath = gcName
return return
} }
gcPath = filepath.Join(runtime.GOROOT(), "/bin/tool/", gcName) gcPath = filepath.Join(build.ToolDir, gcName)
} }
func compile(t *testing.T, dirname, filename string) { func compile(t *testing.T, dirname, filename string) {
...@@ -61,7 +63,7 @@ func testPath(t *testing.T, path string) bool { ...@@ -61,7 +63,7 @@ func testPath(t *testing.T, path string) bool {
const maxTime = 3 * time.Second const maxTime = 3 * time.Second
func testDir(t *testing.T, dir string, endTime time.Time) (nimports int) { func testDir(t *testing.T, dir string, endTime time.Time) (nimports int) {
dirname := filepath.Join(pkgRoot, dir) dirname := filepath.Join(runtime.GOROOT(), "pkg", runtime.GOOS+"_"+runtime.GOARCH, dir)
list, err := ioutil.ReadDir(dirname) list, err := ioutil.ReadDir(dirname)
if err != nil { if err != nil {
t.Errorf("testDir(%s): %s", dirname, err) t.Errorf("testDir(%s): %s", dirname, err)
...@@ -90,6 +92,13 @@ func testDir(t *testing.T, dir string, endTime time.Time) (nimports int) { ...@@ -90,6 +92,13 @@ func testDir(t *testing.T, dir string, endTime time.Time) (nimports int) {
} }
func TestGcImport(t *testing.T) { func TestGcImport(t *testing.T) {
// On cross-compile builds, the path will not exist.
// Need to use GOHOSTOS, which is not available.
if _, err := os.Stat(gcPath); err != nil {
t.Logf("skipping test: %v", err)
return
}
compile(t, "testdata", "exports.go") compile(t, "testdata", "exports.go")
nimports := 0 nimports := 0
......
...@@ -423,6 +423,7 @@ var fmttests = []struct { ...@@ -423,6 +423,7 @@ var fmttests = []struct {
{"p0=%p", new(int), "p0=0xPTR"}, {"p0=%p", new(int), "p0=0xPTR"},
{"p1=%s", &pValue, "p1=String(p)"}, // String method... {"p1=%s", &pValue, "p1=String(p)"}, // String method...
{"p2=%p", &pValue, "p2=0xPTR"}, // ... not called with %p {"p2=%p", &pValue, "p2=0xPTR"}, // ... not called with %p
{"p3=%p", (*int)(nil), "p3=0x0"},
{"p4=%#p", new(int), "p4=PTR"}, {"p4=%#p", new(int), "p4=PTR"},
// %p on non-pointers // %p on non-pointers
...@@ -431,6 +432,14 @@ var fmttests = []struct { ...@@ -431,6 +432,14 @@ var fmttests = []struct {
{"%p", make([]int, 1), "0xPTR"}, {"%p", make([]int, 1), "0xPTR"},
{"%p", 27, "%!p(int=27)"}, // not a pointer at all {"%p", 27, "%!p(int=27)"}, // not a pointer at all
// %q on pointers
{"%q", (*int)(nil), "%!q(*int=<nil>)"},
{"%q", new(int), "%!q(*int=0xPTR)"},
// %v on pointers formats 0 as <nil>
{"%v", (*int)(nil), "<nil>"},
{"%v", new(int), "0xPTR"},
// %d on Stringer should give integer if possible // %d on Stringer should give integer if possible
{"%s", time.Time{}.Month(), "January"}, {"%s", time.Time{}.Month(), "January"},
{"%d", time.Time{}.Month(), "1"}, {"%d", time.Time{}.Month(), "1"},
......
...@@ -553,6 +553,14 @@ func (p *pp) fmtBytes(v []byte, verb rune, goSyntax bool, depth int) { ...@@ -553,6 +553,14 @@ func (p *pp) fmtBytes(v []byte, verb rune, goSyntax bool, depth int) {
} }
func (p *pp) fmtPointer(value reflect.Value, verb rune, goSyntax bool) { func (p *pp) fmtPointer(value reflect.Value, verb rune, goSyntax bool) {
switch verb {
case 'p', 'v', 'b', 'd', 'o', 'x', 'X':
// ok
default:
p.badVerb(verb)
return
}
var u uintptr var u uintptr
switch value.Kind() { switch value.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
...@@ -561,6 +569,7 @@ func (p *pp) fmtPointer(value reflect.Value, verb rune, goSyntax bool) { ...@@ -561,6 +569,7 @@ func (p *pp) fmtPointer(value reflect.Value, verb rune, goSyntax bool) {
p.badVerb(verb) p.badVerb(verb)
return return
} }
if goSyntax { if goSyntax {
p.add('(') p.add('(')
p.buf.WriteString(value.Type().String()) p.buf.WriteString(value.Type().String())
...@@ -572,6 +581,8 @@ func (p *pp) fmtPointer(value reflect.Value, verb rune, goSyntax bool) { ...@@ -572,6 +581,8 @@ func (p *pp) fmtPointer(value reflect.Value, verb rune, goSyntax bool) {
p.fmt0x64(uint64(u), true) p.fmt0x64(uint64(u), true)
} }
p.add(')') p.add(')')
} else if verb == 'v' && u == 0 {
p.buf.Write(nilAngleBytes)
} else { } else {
p.fmt0x64(uint64(u), !p.fmt.sharp) p.fmt0x64(uint64(u), !p.fmt.sharp)
} }
...@@ -929,24 +940,7 @@ BigSwitch: ...@@ -929,24 +940,7 @@ BigSwitch:
break BigSwitch break BigSwitch
} }
} }
if goSyntax { fallthrough
p.buf.WriteByte('(')
p.buf.WriteString(value.Type().String())
p.buf.WriteByte(')')
p.buf.WriteByte('(')
if v == 0 {
p.buf.Write(nilBytes)
} else {
p.fmt0x64(uint64(v), true)
}
p.buf.WriteByte(')')
break
}
if v == 0 {
p.buf.Write(nilAngleBytes)
break
}
p.fmt0x64(uint64(v), true)
case reflect.Chan, reflect.Func, reflect.UnsafePointer: case reflect.Chan, reflect.Func, reflect.UnsafePointer:
p.fmtPointer(value, verb, goSyntax) p.fmtPointer(value, verb, goSyntax)
default: default:
......
...@@ -512,7 +512,7 @@ func (s *ss) scanBool(verb rune) bool { ...@@ -512,7 +512,7 @@ func (s *ss) scanBool(verb rune) bool {
} }
return true return true
case 'f', 'F': case 'f', 'F':
if s.accept("aL") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) { if s.accept("aA") && (!s.accept("lL") || !s.accept("sS") || !s.accept("eE")) {
s.error(boolError) s.error(boolError)
} }
return false return false
......
...@@ -317,6 +317,7 @@ var overflowTests = []ScanTest{ ...@@ -317,6 +317,7 @@ var overflowTests = []ScanTest{
{"(1-1e500i)", &complex128Val, 0}, {"(1-1e500i)", &complex128Val, 0},
} }
var truth bool
var i, j, k int var i, j, k int
var f float64 var f float64
var s, t string var s, t string
...@@ -350,6 +351,9 @@ var multiTests = []ScanfMultiTest{ ...@@ -350,6 +351,9 @@ var multiTests = []ScanfMultiTest{
// Bad UTF-8: should see every byte. // Bad UTF-8: should see every byte.
{"%c%c%c", "\xc2X\xc2", args(&r1, &r2, &r3), args(utf8.RuneError, 'X', utf8.RuneError), ""}, {"%c%c%c", "\xc2X\xc2", args(&r1, &r2, &r3), args(utf8.RuneError, 'X', utf8.RuneError), ""},
// Fixed bugs
{"%v%v", "FALSE23", args(&truth, &i), args(false, 23), ""},
} }
func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{}) (int, error)) { func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{}) (int, error)) {
......
...@@ -14,12 +14,12 @@ import ( ...@@ -14,12 +14,12 @@ import (
) )
type pkgBuilder struct { type pkgBuilder struct {
scanner.ErrorVector fset *token.FileSet
fset *token.FileSet errors scanner.ErrorList
} }
func (p *pkgBuilder) error(pos token.Pos, msg string) { func (p *pkgBuilder) error(pos token.Pos, msg string) {
p.Error(p.fset.Position(pos), msg) p.errors.Add(p.fset.Position(pos), msg)
} }
func (p *pkgBuilder) errorf(pos token.Pos, format string, args ...interface{}) { func (p *pkgBuilder) errorf(pos token.Pos, format string, args ...interface{}) {
...@@ -169,5 +169,6 @@ func NewPackage(fset *token.FileSet, files map[string]*File, importer Importer, ...@@ -169,5 +169,6 @@ func NewPackage(fset *token.FileSet, files map[string]*File, importer Importer,
pkgScope.Outer = universe // reset universe scope pkgScope.Outer = universe // reset universe scope
} }
return &Package{pkgName, pkgScope, imports, files}, p.GetError(scanner.Sorted) p.errors.Sort()
return &Package{pkgName, pkgScope, imports, files}, p.errors.Err()
} }
...@@ -25,10 +25,11 @@ import ( ...@@ -25,10 +25,11 @@ import (
// A Context specifies the supporting context for a build. // A Context specifies the supporting context for a build.
type Context struct { type Context struct {
GOARCH string // target architecture GOARCH string // target architecture
GOOS string // target operating system GOOS string // target operating system
CgoEnabled bool // whether cgo can be used CgoEnabled bool // whether cgo can be used
BuildTags []string // additional tags to recognize in +build lines BuildTags []string // additional tags to recognize in +build lines
UseAllFiles bool // use files regardless of +build lines, file names
// By default, ScanDir uses the operating system's // By default, ScanDir uses the operating system's
// file system calls to read directories and files. // file system calls to read directories and files.
...@@ -225,6 +226,7 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) { ...@@ -225,6 +226,7 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) {
var Sfiles []string // files with ".S" (capital S) var Sfiles []string // files with ".S" (capital S)
var di DirInfo var di DirInfo
var firstFile string
imported := make(map[string][]token.Position) imported := make(map[string][]token.Position)
testImported := make(map[string][]token.Position) testImported := make(map[string][]token.Position)
fset := token.NewFileSet() fset := token.NewFileSet()
...@@ -237,7 +239,7 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) { ...@@ -237,7 +239,7 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) {
strings.HasPrefix(name, ".") { strings.HasPrefix(name, ".") {
continue continue
} }
if !ctxt.goodOSArchFile(name) { if !ctxt.UseAllFiles && !ctxt.goodOSArchFile(name) {
continue continue
} }
...@@ -250,12 +252,13 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) { ...@@ -250,12 +252,13 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) {
continue continue
} }
// Look for +build comments to accept or reject the file.
filename, data, err := ctxt.readFile(dir, name) filename, data, err := ctxt.readFile(dir, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !ctxt.shouldBuild(data) {
// Look for +build comments to accept or reject the file.
if !ctxt.UseAllFiles && !ctxt.shouldBuild(data) {
continue continue
} }
...@@ -281,9 +284,6 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) { ...@@ -281,9 +284,6 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) {
} }
pkg := string(pf.Name.Name) pkg := string(pf.Name.Name)
if pkg == "main" && di.Package != "" && di.Package != "main" {
continue
}
if pkg == "documentation" { if pkg == "documentation" {
continue continue
} }
...@@ -293,15 +293,11 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) { ...@@ -293,15 +293,11 @@ func (ctxt *Context) ScanDir(dir string) (info *DirInfo, err error) {
pkg = pkg[:len(pkg)-len("_test")] pkg = pkg[:len(pkg)-len("_test")]
} }
if pkg != di.Package && di.Package == "main" {
// Found non-main package but was recording
// information about package main. Reset.
di = DirInfo{}
}
if di.Package == "" { if di.Package == "" {
di.Package = pkg di.Package = pkg
firstFile = name
} else if pkg != di.Package { } else if pkg != di.Package {
return nil, fmt.Errorf("%s: found packages %s and %s", dir, pkg, di.Package) return nil, fmt.Errorf("%s: found packages %s (%s) and %s (%s)", dir, di.Package, firstFile, pkg, name)
} }
if pf.Doc != nil { if pf.Doc != nil {
if di.PackageComment != nil { if di.PackageComment != nil {
......
...@@ -12,6 +12,9 @@ import ( ...@@ -12,6 +12,9 @@ import (
"runtime" "runtime"
) )
// ToolDir is the directory containing build tools.
var ToolDir = filepath.Join(runtime.GOROOT(), "pkg/tool/"+runtime.GOOS+"_"+runtime.GOARCH)
// Path is a validated list of Trees derived from $GOROOT and $GOPATH at init. // Path is a validated list of Trees derived from $GOROOT and $GOPATH at init.
var Path []*Tree var Path []*Tree
......
...@@ -14,12 +14,14 @@ import ( ...@@ -14,12 +14,14 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strings" "strings"
"testing" "testing"
"text/template" "text/template"
) )
var update = flag.Bool("update", false, "update golden (.out) files") var update = flag.Bool("update", false, "update golden (.out) files")
var files = flag.String("files", "", "consider only Go test files matching this regular expression")
const dataDir = "testdata" const dataDir = "testdata"
...@@ -66,14 +68,26 @@ type bundle struct { ...@@ -66,14 +68,26 @@ type bundle struct {
} }
func test(t *testing.T, mode Mode) { func test(t *testing.T, mode Mode) {
// get all packages // determine file filter
filter := isGoFile
if *files != "" {
rx, err := regexp.Compile(*files)
if err != nil {
t.Fatal(err)
}
filter = func(fi os.FileInfo) bool {
return isGoFile(fi) && rx.MatchString(fi.Name())
}
}
// get packages
fset := token.NewFileSet() fset := token.NewFileSet()
pkgs, err := parser.ParseDir(fset, dataDir, isGoFile, parser.ParseComments) pkgs, err := parser.ParseDir(fset, dataDir, filter, parser.ParseComments)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// test all packages // test packages
for _, pkg := range pkgs { for _, pkg := range pkgs {
importpath := dataDir + "/" + pkg.Name importpath := dataDir + "/" + pkg.Name
doc := New(pkg, importpath, mode) doc := New(pkg, importpath, mode)
......
...@@ -9,6 +9,7 @@ package doc ...@@ -9,6 +9,7 @@ package doc
import ( import (
"go/ast" "go/ast"
"go/printer" "go/printer"
"go/token"
"strings" "strings"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
...@@ -21,28 +22,47 @@ type Example struct { ...@@ -21,28 +22,47 @@ type Example struct {
} }
func Examples(pkg *ast.Package) []*Example { func Examples(pkg *ast.Package) []*Example {
var examples []*Example var list []*Example
for _, src := range pkg.Files { for _, file := range pkg.Files {
for _, decl := range src.Decls { hasTests := false // file contains tests or benchmarks
numDecl := 0 // number of non-import declarations in the file
var flist []*Example
for _, decl := range file.Decls {
if g, ok := decl.(*ast.GenDecl); ok && g.Tok != token.IMPORT {
numDecl++
continue
}
f, ok := decl.(*ast.FuncDecl) f, ok := decl.(*ast.FuncDecl)
if !ok { if !ok {
continue continue
} }
numDecl++
name := f.Name.Name name := f.Name.Name
if isTest(name, "Test") || isTest(name, "Benchmark") {
hasTests = true
continue
}
if !isTest(name, "Example") { if !isTest(name, "Example") {
continue continue
} }
examples = append(examples, &Example{ flist = append(flist, &Example{
Name: name[len("Example"):], Name: name[len("Example"):],
Body: &printer.CommentedNode{ Body: &printer.CommentedNode{
Node: f.Body, Node: f.Body,
Comments: src.Comments, Comments: file.Comments,
}, },
Output: f.Doc.Text(), Output: f.Doc.Text(),
}) })
} }
if !hasTests && numDecl > 1 && len(flist) == 1 {
// If this file only has one example function, some
// other top-level declarations, and no tests or
// benchmarks, use the whole file as the example.
flist[0].Body.Node = file
}
list = append(list, flist...)
} }
return examples return list
} }
// isTest tells whether name looks like a test, example, or benchmark. // isTest tells whether name looks like a test, example, or benchmark.
......
...@@ -22,12 +22,38 @@ func filterIdentList(list []*ast.Ident) []*ast.Ident { ...@@ -22,12 +22,38 @@ func filterIdentList(list []*ast.Ident) []*ast.Ident {
return list[0:j] return list[0:j]
} }
// removeErrorField removes anonymous fields named "error" from an interface.
// This is called when "error" has been determined to be a local name,
// not the predeclared type.
//
func removeErrorField(ityp *ast.InterfaceType) {
list := ityp.Methods.List // we know that ityp.Methods != nil
j := 0
for _, field := range list {
keepField := true
if n := len(field.Names); n == 0 {
// anonymous field
if fname, _ := baseTypeName(field.Type); fname == "error" {
keepField = false
}
}
if keepField {
list[j] = field
j++
}
}
if j < len(list) {
ityp.Incomplete = true
}
ityp.Methods.List = list[0:j]
}
// filterFieldList removes unexported fields (field names) from the field list // filterFieldList removes unexported fields (field names) from the field list
// in place and returns true if fields were removed. Anonymous fields are // in place and returns true if fields were removed. Anonymous fields are
// recorded with the parent type. filterType is called with the types of // recorded with the parent type. filterType is called with the types of
// all remaining fields. // all remaining fields.
// //
func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList) (removedFields bool) { func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList, ityp *ast.InterfaceType) (removedFields bool) {
if fields == nil { if fields == nil {
return return
} }
...@@ -37,9 +63,15 @@ func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList) (remo ...@@ -37,9 +63,15 @@ func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList) (remo
keepField := false keepField := false
if n := len(field.Names); n == 0 { if n := len(field.Names); n == 0 {
// anonymous field // anonymous field
name := r.recordAnonymousField(parent, field.Type) fname := r.recordAnonymousField(parent, field.Type)
if ast.IsExported(name) { if ast.IsExported(fname) {
keepField = true
} else if ityp != nil && fname == "error" {
// possibly the predeclared error interface; keep
// it for now but remember this interface so that
// it can be fixed if error is also defined locally
keepField = true keepField = true
r.remember(ityp)
} }
} else { } else {
field.Names = filterIdentList(field.Names) field.Names = filterIdentList(field.Names)
...@@ -86,14 +118,14 @@ func (r *reader) filterType(parent *namedType, typ ast.Expr) { ...@@ -86,14 +118,14 @@ func (r *reader) filterType(parent *namedType, typ ast.Expr) {
case *ast.ArrayType: case *ast.ArrayType:
r.filterType(nil, t.Elt) r.filterType(nil, t.Elt)
case *ast.StructType: case *ast.StructType:
if r.filterFieldList(parent, t.Fields) { if r.filterFieldList(parent, t.Fields, nil) {
t.Incomplete = true t.Incomplete = true
} }
case *ast.FuncType: case *ast.FuncType:
r.filterParamList(t.Params) r.filterParamList(t.Params)
r.filterParamList(t.Results) r.filterParamList(t.Results)
case *ast.InterfaceType: case *ast.InterfaceType:
if r.filterFieldList(parent, t.Methods) { if r.filterFieldList(parent, t.Methods, t) {
t.Incomplete = true t.Incomplete = true
} }
case *ast.MapType: case *ast.MapType:
...@@ -116,9 +148,12 @@ func (r *reader) filterSpec(spec ast.Spec) bool { ...@@ -116,9 +148,12 @@ func (r *reader) filterSpec(spec ast.Spec) bool {
return true return true
} }
case *ast.TypeSpec: case *ast.TypeSpec:
if ast.IsExported(s.Name.Name) { if name := s.Name.Name; ast.IsExported(name) {
r.filterType(r.lookupType(s.Name.Name), s.Type) r.filterType(r.lookupType(s.Name.Name), s.Type)
return true return true
} else if name == "error" {
// special case: remember that error is declared locally
r.errorDecl = true
} }
} }
return false return false
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build ignore
/* /*
The headscan command extracts comment headings from package files; The headscan command extracts comment headings from package files;
it is used to detect false positives which may require an adjustment it is used to detect false positives which may require an adjustment
......
...@@ -17,7 +17,7 @@ import ( ...@@ -17,7 +17,7 @@ import (
// //
// Internally, we treat functions like methods and collect them in method sets. // Internally, we treat functions like methods and collect them in method sets.
// methodSet describes a set of methods. Entries where Decl == nil are conflict // A methodSet describes a set of methods. Entries where Decl == nil are conflict
// entries (more then one method with the same name at the same embedding level). // entries (more then one method with the same name at the same embedding level).
// //
type methodSet map[string]*Func type methodSet map[string]*Func
...@@ -110,6 +110,9 @@ func baseTypeName(x ast.Expr) (name string, imported bool) { ...@@ -110,6 +110,9 @@ func baseTypeName(x ast.Expr) (name string, imported bool) {
return return
} }
// An embeddedSet describes a set of embedded types.
type embeddedSet map[*namedType]bool
// A namedType represents a named unqualified (package local, or possibly // A namedType represents a named unqualified (package local, or possibly
// predeclared) type. The namedType for a type name is always found via // predeclared) type. The namedType for a type name is always found via
// reader.lookupType. // reader.lookupType.
...@@ -119,9 +122,9 @@ type namedType struct { ...@@ -119,9 +122,9 @@ type namedType struct {
name string // type name name string // type name
decl *ast.GenDecl // nil if declaration hasn't been seen yet decl *ast.GenDecl // nil if declaration hasn't been seen yet
isEmbedded bool // true if this type is embedded isEmbedded bool // true if this type is embedded
isStruct bool // true if this type is a struct isStruct bool // true if this type is a struct
embedded map[*namedType]bool // true if the embedded type is a pointer embedded embeddedSet // true if the embedded type is a pointer
// associated declarations // associated declarations
values []*Value // consts and vars values []*Value // consts and vars
...@@ -152,6 +155,10 @@ type reader struct { ...@@ -152,6 +155,10 @@ type reader struct {
values []*Value // consts and vars values []*Value // consts and vars
types map[string]*namedType types map[string]*namedType
funcs methodSet funcs methodSet
// support for package-local error type declarations
errorDecl bool // if set, type "error" was declared locally
fixlist []*ast.InterfaceType // list of interfaces containing anonymous field "error"
} }
func (r *reader) isVisible(name string) bool { func (r *reader) isVisible(name string) bool {
...@@ -173,7 +180,7 @@ func (r *reader) lookupType(name string) *namedType { ...@@ -173,7 +180,7 @@ func (r *reader) lookupType(name string) *namedType {
// type not found - add one without declaration // type not found - add one without declaration
typ := &namedType{ typ := &namedType{
name: name, name: name,
embedded: make(map[*namedType]bool), embedded: make(embeddedSet),
funcs: make(methodSet), funcs: make(methodSet),
methods: make(methodSet), methods: make(methodSet),
} }
...@@ -210,6 +217,10 @@ func (r *reader) readDoc(comment *ast.CommentGroup) { ...@@ -210,6 +217,10 @@ func (r *reader) readDoc(comment *ast.CommentGroup) {
r.doc += "\n" + text r.doc += "\n" + text
} }
func (r *reader) remember(typ *ast.InterfaceType) {
r.fixlist = append(r.fixlist, typ)
}
func specNames(specs []ast.Spec) []string { func specNames(specs []ast.Spec) []string {
names := make([]string, 0, len(specs)) // reasonable estimate names := make([]string, 0, len(specs)) // reasonable estimate
for _, s := range specs { for _, s := range specs {
...@@ -274,7 +285,7 @@ func (r *reader) readValue(decl *ast.GenDecl) { ...@@ -274,7 +285,7 @@ func (r *reader) readValue(decl *ast.GenDecl) {
// determine values list with which to associate the Value for this decl // determine values list with which to associate the Value for this decl
values := &r.values values := &r.values
const threshold = 0.75 const threshold = 0.75
if domName != "" && domFreq >= int(float64(len(decl.Specs))*threshold) { if domName != "" && r.isVisible(domName) && domFreq >= int(float64(len(decl.Specs))*threshold) {
// typed entries are sufficiently frequent // typed entries are sufficiently frequent
if typ := r.lookupType(domName); typ != nil { if typ := r.lookupType(domName); typ != nil {
values = &typ.values // associate with that type values = &typ.values // associate with that type
...@@ -315,7 +326,7 @@ func (r *reader) readType(decl *ast.GenDecl, spec *ast.TypeSpec) { ...@@ -315,7 +326,7 @@ func (r *reader) readType(decl *ast.GenDecl, spec *ast.TypeSpec) {
return // no name or blank name - ignore the type return // no name or blank name - ignore the type
} }
// A type should be added at most once, so info.decl // A type should be added at most once, so typ.decl
// should be nil - if it is not, simply overwrite it. // should be nil - if it is not, simply overwrite it.
typ.decl = decl typ.decl = decl
...@@ -543,7 +554,8 @@ func customizeRecv(f *Func, recvTypeName string, embeddedIsPtr bool, level int) ...@@ -543,7 +554,8 @@ func customizeRecv(f *Func, recvTypeName string, embeddedIsPtr bool, level int)
// collectEmbeddedMethods collects the embedded methods of typ in mset. // collectEmbeddedMethods collects the embedded methods of typ in mset.
// //
func (r *reader) collectEmbeddedMethods(mset methodSet, typ *namedType, recvTypeName string, embeddedIsPtr bool, level int) { func (r *reader) collectEmbeddedMethods(mset methodSet, typ *namedType, recvTypeName string, embeddedIsPtr bool, level int, visited embeddedSet) {
visited[typ] = true
for embedded, isPtr := range typ.embedded { for embedded, isPtr := range typ.embedded {
// Once an embedded type is embedded as a pointer type // Once an embedded type is embedded as a pointer type
// all embedded types in those types are treated like // all embedded types in those types are treated like
...@@ -557,8 +569,11 @@ func (r *reader) collectEmbeddedMethods(mset methodSet, typ *namedType, recvType ...@@ -557,8 +569,11 @@ func (r *reader) collectEmbeddedMethods(mset methodSet, typ *namedType, recvType
mset.add(customizeRecv(m, recvTypeName, thisEmbeddedIsPtr, level)) mset.add(customizeRecv(m, recvTypeName, thisEmbeddedIsPtr, level))
} }
} }
r.collectEmbeddedMethods(mset, embedded, recvTypeName, thisEmbeddedIsPtr, level+1) if !visited[embedded] {
r.collectEmbeddedMethods(mset, embedded, recvTypeName, thisEmbeddedIsPtr, level+1, visited)
}
} }
delete(visited, typ)
} }
// computeMethodSets determines the actual method sets for each type encountered. // computeMethodSets determines the actual method sets for each type encountered.
...@@ -568,12 +583,19 @@ func (r *reader) computeMethodSets() { ...@@ -568,12 +583,19 @@ func (r *reader) computeMethodSets() {
// collect embedded methods for t // collect embedded methods for t
if t.isStruct { if t.isStruct {
// struct // struct
r.collectEmbeddedMethods(t.methods, t, t.name, false, 1) r.collectEmbeddedMethods(t.methods, t, t.name, false, 1, make(embeddedSet))
} else { } else {
// interface // interface
// TODO(gri) fix this // TODO(gri) fix this
} }
} }
// if error was declared locally, don't treat it as exported field anymore
if r.errorDecl {
for _, ityp := range r.fixlist {
removeErrorField(ityp)
}
}
} }
// cleanupTypes removes the association of functions and methods with // cleanupTypes removes the association of functions and methods with
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment