Commit adb0401d by Ian Lance Taylor

Update Go library to r60.

From-SVN: r178910
parent 5548ca35
...@@ -806,7 +806,7 @@ proc go-gc-tests { } { ...@@ -806,7 +806,7 @@ proc go-gc-tests { } {
$status $name $status $name
} else { } else {
verbose -log $comp_output verbose -log $comp_output
fali $name fail $name
} }
file delete $ofile1 $ofile2 $output_file file delete $ofile1 $ofile2 $output_file
set runtests $hold_runtests set runtests $hold_runtests
......
...@@ -38,7 +38,7 @@ func Listen(x, y string) (T, string) { ...@@ -38,7 +38,7 @@ func Listen(x, y string) (T, string) {
} }
func (t T) Addr() os.Error { func (t T) Addr() os.Error {
return os.ErrorString("stringer") return os.NewError("stringer")
} }
func (t T) Accept() (int, string) { func (t T) Accept() (int, string) {
...@@ -49,4 +49,3 @@ func Dial(x, y, z string) (int, string) { ...@@ -49,4 +49,3 @@ func Dial(x, y, z string) (int, string) {
global <- 1 global <- 1
return 0, "" return 0, ""
} }
...@@ -18,6 +18,7 @@ var chatty = flag.Bool("v", false, "chatty") ...@@ -18,6 +18,7 @@ var chatty = flag.Bool("v", false, "chatty")
var oldsys uint64 var oldsys uint64
func bigger() { func bigger() {
runtime.UpdateMemStats()
if st := runtime.MemStats; oldsys < st.Sys { if st := runtime.MemStats; oldsys < st.Sys {
oldsys = st.Sys oldsys = st.Sys
if *chatty { if *chatty {
...@@ -31,7 +32,7 @@ func bigger() { ...@@ -31,7 +32,7 @@ func bigger() {
} }
func main() { func main() {
runtime.GC() // clean up garbage from init runtime.GC() // clean up garbage from init
runtime.MemProfileRate = 0 // disable profiler runtime.MemProfileRate = 0 // disable profiler
runtime.MemStats.Alloc = 0 // ignore stacks runtime.MemStats.Alloc = 0 // ignore stacks
flag.Parse() flag.Parse()
...@@ -45,8 +46,10 @@ func main() { ...@@ -45,8 +46,10 @@ func main() {
panic("fail") panic("fail")
} }
b := runtime.Alloc(uintptr(j)) b := runtime.Alloc(uintptr(j))
runtime.UpdateMemStats()
during := runtime.MemStats.Alloc during := runtime.MemStats.Alloc
runtime.Free(b) runtime.Free(b)
runtime.UpdateMemStats()
if a := runtime.MemStats.Alloc; a != 0 { if a := runtime.MemStats.Alloc; a != 0 {
println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)") println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)")
panic("fail") panic("fail")
......
...@@ -42,6 +42,7 @@ func AllocAndFree(size, count int) { ...@@ -42,6 +42,7 @@ func AllocAndFree(size, count int) {
if *chatty { if *chatty {
fmt.Printf("size=%d count=%d ...\n", size, count) fmt.Printf("size=%d count=%d ...\n", size, count)
} }
runtime.UpdateMemStats()
n1 := stats.Alloc n1 := stats.Alloc
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
b[i] = runtime.Alloc(uintptr(size)) b[i] = runtime.Alloc(uintptr(size))
...@@ -50,11 +51,13 @@ func AllocAndFree(size, count int) { ...@@ -50,11 +51,13 @@ func AllocAndFree(size, count int) {
println("lookup failed: got", base, n, "for", b[i]) println("lookup failed: got", base, n, "for", b[i])
panic("fail") panic("fail")
} }
if runtime.MemStats.Sys > 1e9 { runtime.UpdateMemStats()
if stats.Sys > 1e9 {
println("too much memory allocated") println("too much memory allocated")
panic("fail") panic("fail")
} }
} }
runtime.UpdateMemStats()
n2 := stats.Alloc n2 := stats.Alloc
if *chatty { if *chatty {
fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats) fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats)
...@@ -72,6 +75,7 @@ func AllocAndFree(size, count int) { ...@@ -72,6 +75,7 @@ func AllocAndFree(size, count int) {
panic("fail") panic("fail")
} }
runtime.Free(b[i]) runtime.Free(b[i])
runtime.UpdateMemStats()
if stats.Alloc != uint64(alloc-n) { if stats.Alloc != uint64(alloc-n) {
println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n) println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n)
panic("fail") panic("fail")
...@@ -81,6 +85,7 @@ func AllocAndFree(size, count int) { ...@@ -81,6 +85,7 @@ func AllocAndFree(size, count int) {
panic("fail") panic("fail")
} }
} }
runtime.UpdateMemStats()
n4 := stats.Alloc n4 := stats.Alloc
if *chatty { if *chatty {
......
aea0ba6e5935 504f4e9b079c
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.
...@@ -12,12 +12,24 @@ ...@@ -12,12 +12,24 @@
/* Define to 1 if you have the <inttypes.h> header file. */ /* Define to 1 if you have the <inttypes.h> header file. */
#undef HAVE_INTTYPES_H #undef HAVE_INTTYPES_H
/* Define to 1 if you have the <linux/filter.h> header file. */
#undef HAVE_LINUX_FILTER_H
/* Define to 1 if you have the <linux/netlink.h> header file. */
#undef HAVE_LINUX_NETLINK_H
/* Define to 1 if you have the <linux/rtnetlink.h> header file. */
#undef HAVE_LINUX_RTNETLINK_H
/* Define to 1 if you have the <memory.h> header file. */ /* Define to 1 if you have the <memory.h> header file. */
#undef HAVE_MEMORY_H #undef HAVE_MEMORY_H
/* Define to 1 if you have the `mincore' function. */ /* Define to 1 if you have the `mincore' function. */
#undef HAVE_MINCORE #undef HAVE_MINCORE
/* Define to 1 if you have the <net/if.h> header file. */
#undef HAVE_NET_IF_H
/* Define to 1 if the system has the type `off64_t'. */ /* Define to 1 if the system has the type `off64_t'. */
#undef HAVE_OFF64_T #undef HAVE_OFF64_T
...@@ -71,6 +83,9 @@ ...@@ -71,6 +83,9 @@
/* Define to 1 if you have the <sys/select.h> header file. */ /* Define to 1 if you have the <sys/select.h> header file. */
#undef HAVE_SYS_SELECT_H #undef HAVE_SYS_SELECT_H
/* Define to 1 if you have the <sys/socket.h> header file. */
#undef HAVE_SYS_SOCKET_H
/* Define to 1 if you have the <sys/stat.h> header file. */ /* Define to 1 if you have the <sys/stat.h> header file. */
#undef HAVE_SYS_STAT_H #undef HAVE_SYS_STAT_H
......
...@@ -617,7 +617,6 @@ USING_SPLIT_STACK_FALSE ...@@ -617,7 +617,6 @@ USING_SPLIT_STACK_FALSE
USING_SPLIT_STACK_TRUE USING_SPLIT_STACK_TRUE
SPLIT_STACK SPLIT_STACK
OSCFLAGS OSCFLAGS
GO_DEBUG_PROC_REGS_OS_ARCH_FILE
GO_SYSCALLS_SYSCALL_OS_ARCH_FILE GO_SYSCALLS_SYSCALL_OS_ARCH_FILE
GOARCH GOARCH
LIBGO_IS_X86_64_FALSE LIBGO_IS_X86_64_FALSE
...@@ -10914,7 +10913,7 @@ else ...@@ -10914,7 +10913,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2 lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF cat > conftest.$ac_ext <<_LT_EOF
#line 10917 "configure" #line 10916 "configure"
#include "confdefs.h" #include "confdefs.h"
#if HAVE_DLFCN_H #if HAVE_DLFCN_H
...@@ -11020,7 +11019,7 @@ else ...@@ -11020,7 +11019,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2 lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF cat > conftest.$ac_ext <<_LT_EOF
#line 11023 "configure" #line 11022 "configure"
#include "confdefs.h" #include "confdefs.h"
#if HAVE_DLFCN_H #if HAVE_DLFCN_H
...@@ -13558,12 +13557,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then ...@@ -13558,12 +13557,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
fi fi
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
fi
case "$target" in case "$target" in
mips-sgi-irix6.5*) mips-sgi-irix6.5*)
# IRIX 6 needs _XOPEN_SOURCE=500 for the XPG5 version of struct # IRIX 6 needs _XOPEN_SOURCE=500 for the XPG5 version of struct
...@@ -14252,7 +14245,7 @@ no) ...@@ -14252,7 +14245,7 @@ no)
;; ;;
esac esac
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h
do : do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh` as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default" ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
...@@ -14266,6 +14259,26 @@ fi ...@@ -14266,6 +14259,26 @@ fi
done done
for ac_header in linux/filter.h linux/netlink.h linux/rtnetlink.h
do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_compile "$LINENO" "$ac_header" "$as_ac_Header" "#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
"
eval as_val=\$$as_ac_Header
if test "x$as_val" = x""yes; then :
cat >>confdefs.h <<_ACEOF
#define `$as_echo "HAVE_$ac_header" | $as_tr_cpp` 1
_ACEOF
fi
done
if test "$ac_cv_header_sys_mman_h" = yes; then if test "$ac_cv_header_sys_mman_h" = yes; then
HAVE_SYS_MMAN_H_TRUE= HAVE_SYS_MMAN_H_TRUE=
HAVE_SYS_MMAN_H_FALSE='#' HAVE_SYS_MMAN_H_FALSE='#'
......
...@@ -255,12 +255,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then ...@@ -255,12 +255,6 @@ if test -f ${srcdir}/syscalls/syscall_${GOOS}_${GOARCH}.go; then
fi fi
AC_SUBST(GO_SYSCALLS_SYSCALL_OS_ARCH_FILE) AC_SUBST(GO_SYSCALLS_SYSCALL_OS_ARCH_FILE)
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=
if test -f ${srcdir}/go/debug/proc/regs_${GOOS}_${GOARCH}.go; then
GO_DEBUG_PROC_REGS_OS_ARCH_FILE=go/debug/proc/regs_${GOOS}_${GOARCH}.go
fi
AC_SUBST(GO_DEBUG_PROC_REGS_OS_ARCH_FILE)
dnl Some targets need special flags to build sysinfo.go. dnl Some targets need special flags to build sysinfo.go.
case "$target" in case "$target" in
mips-sgi-irix6.5*) mips-sgi-irix6.5*)
...@@ -431,7 +425,14 @@ no) ...@@ -431,7 +425,14 @@ no)
;; ;;
esac esac
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h) AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h)
AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
[#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
])
AM_CONDITIONAL(HAVE_SYS_MMAN_H, test "$ac_cv_header_sys_mman_h" = yes) AM_CONDITIONAL(HAVE_SYS_MMAN_H, test "$ac_cv_header_sys_mman_h" = yes)
AC_CHECK_FUNCS(srandom random strerror_r strsignal wait4 mincore setenv) AC_CHECK_FUNCS(srandom random strerror_r strsignal wait4 mincore setenv)
......
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
) )
var ( var (
HeaderError os.Error = os.ErrorString("invalid tar header") HeaderError = os.NewError("invalid tar header")
) )
// A Reader provides sequential access to the contents of a tar archive. // A Reader provides sequential access to the contents of a tar archive.
......
...@@ -178,7 +178,6 @@ func TestPartialRead(t *testing.T) { ...@@ -178,7 +178,6 @@ func TestPartialRead(t *testing.T) {
} }
} }
func TestIncrementalRead(t *testing.T) { func TestIncrementalRead(t *testing.T) {
test := gnuTarTest test := gnuTarTest
f, err := os.Open(test.file) f, err := os.Open(test.file)
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time"
) )
type ZipTest struct { type ZipTest struct {
...@@ -24,8 +25,19 @@ type ZipTestFile struct { ...@@ -24,8 +25,19 @@ type ZipTestFile struct {
Name string Name string
Content []byte // if blank, will attempt to compare against File Content []byte // if blank, will attempt to compare against File
File string // name of file to compare to (relative to testdata/) File string // name of file to compare to (relative to testdata/)
Mtime string // modified time in format "mm-dd-yy hh:mm:ss"
} }
// Caution: The Mtime values found for the test files should correspond to
// the values listed with unzip -l <zipfile>. However, the values
// listed by unzip appear to be off by some hours. When creating
// fresh test files and testing them, this issue is not present.
// The test files were created in Sydney, so there might be a time
// zone issue. The time zone information does have to be encoded
// somewhere, because otherwise unzip -l could not provide a different
// time from what the archive/zip package provides, but there appears
// to be no documentation about this.
var tests = []ZipTest{ var tests = []ZipTest{
{ {
Name: "test.zip", Name: "test.zip",
...@@ -34,10 +46,12 @@ var tests = []ZipTest{ ...@@ -34,10 +46,12 @@ var tests = []ZipTest{
{ {
Name: "test.txt", Name: "test.txt",
Content: []byte("This is a test text file.\n"), Content: []byte("This is a test text file.\n"),
Mtime: "09-05-10 12:12:02",
}, },
{ {
Name: "gophercolor16x16.png", Name: "gophercolor16x16.png",
File: "gophercolor16x16.png", File: "gophercolor16x16.png",
Mtime: "09-05-10 15:52:58",
}, },
}, },
}, },
...@@ -45,8 +59,9 @@ var tests = []ZipTest{ ...@@ -45,8 +59,9 @@ var tests = []ZipTest{
Name: "r.zip", Name: "r.zip",
File: []ZipTestFile{ File: []ZipTestFile{
{ {
Name: "r/r.zip", Name: "r/r.zip",
File: "r.zip", File: "r.zip",
Mtime: "03-04-10 00:24:16",
}, },
}, },
}, },
...@@ -58,6 +73,7 @@ var tests = []ZipTest{ ...@@ -58,6 +73,7 @@ var tests = []ZipTest{
{ {
Name: "filename", Name: "filename",
Content: []byte("This is a test textfile.\n"), Content: []byte("This is a test textfile.\n"),
Mtime: "02-02-11 13:06:20",
}, },
}, },
}, },
...@@ -136,18 +152,36 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) { ...@@ -136,18 +152,36 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
if f.Name != ft.Name { if f.Name != ft.Name {
t.Errorf("name=%q, want %q", f.Name, ft.Name) t.Errorf("name=%q, want %q", f.Name, ft.Name)
} }
mtime, err := time.Parse("01-02-06 15:04:05", ft.Mtime)
if err != nil {
t.Error(err)
return
}
if got, want := f.Mtime_ns()/1e9, mtime.Seconds(); got != want {
t.Errorf("%s: mtime=%s (%d); want %s (%d)", f.Name, time.SecondsToUTC(got), got, mtime, want)
}
size0 := f.UncompressedSize
var b bytes.Buffer var b bytes.Buffer
r, err := f.Open() r, err := f.Open()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
if size1 := f.UncompressedSize; size0 != size1 {
t.Errorf("file %q changed f.UncompressedSize from %d to %d", f.Name, size0, size1)
}
_, err = io.Copy(&b, r) _, err = io.Copy(&b, r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
r.Close() r.Close()
var c []byte var c []byte
if len(ft.Content) != 0 { if len(ft.Content) != 0 {
c = ft.Content c = ft.Content
...@@ -155,10 +189,12 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) { ...@@ -155,10 +189,12 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
t.Error(err) t.Error(err)
return return
} }
if b.Len() != len(c) { if b.Len() != len(c) {
t.Errorf("%s: len=%d, want %d", f.Name, b.Len(), len(c)) t.Errorf("%s: len=%d, want %d", f.Name, b.Len(), len(c))
return return
} }
for i, b := range b.Bytes() { for i, b := range b.Bytes() {
if b != c[i] { if b != c[i] {
t.Errorf("%s: content[%d]=%q want %q", f.Name, i, b, c[i]) t.Errorf("%s: content[%d]=%q want %q", f.Name, i, b, c[i])
......
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package zip provides support for reading and writing ZIP archives.
See: http://www.pkware.com/documents/casestudies/APPNOTE.TXT
This package does not support ZIP64 or disk spanning.
*/
package zip package zip
import "os"
import "time"
// Compression methods.
const (
Store uint16 = 0
Deflate uint16 = 8
)
const ( const (
fileHeaderSignature = 0x04034b50 fileHeaderSignature = 0x04034b50
directoryHeaderSignature = 0x02014b50 directoryHeaderSignature = 0x02014b50
directoryEndSignature = 0x06054b50 directoryEndSignature = 0x06054b50
fileHeaderLen = 30 // + filename + extra
directoryHeaderLen = 46 // + filename + extra + comment
directoryEndLen = 22 // + comment
dataDescriptorLen = 12 dataDescriptorLen = 12
) )
...@@ -13,8 +36,8 @@ type FileHeader struct { ...@@ -13,8 +36,8 @@ type FileHeader struct {
ReaderVersion uint16 ReaderVersion uint16
Flags uint16 Flags uint16
Method uint16 Method uint16
ModifiedTime uint16 ModifiedTime uint16 // MS-DOS time
ModifiedDate uint16 ModifiedDate uint16 // MS-DOS date
CRC32 uint32 CRC32 uint32
CompressedSize uint32 CompressedSize uint32
UncompressedSize uint32 UncompressedSize uint32
...@@ -32,3 +55,37 @@ type directoryEnd struct { ...@@ -32,3 +55,37 @@ type directoryEnd struct {
commentLen uint16 commentLen uint16
comment string comment string
} }
func recoverError(err *os.Error) {
if e := recover(); e != nil {
if osErr, ok := e.(os.Error); ok {
*err = osErr
return
}
panic(e)
}
}
// msDosTimeToTime converts an MS-DOS date and time into a time.Time.
// The resolution is 2s.
// See: http://msdn.microsoft.com/en-us/library/ms724247(v=VS.85).aspx
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
return time.Time{
// date bits 0-4: day of month; 5-8: month; 9-15: years since 1980
Year: int64(dosDate>>9 + 1980),
Month: int(dosDate >> 5 & 0xf),
Day: int(dosDate & 0x1f),
// time bits 0-4: second/2; 5-10: minute; 11-15: hour
Hour: int(dosTime >> 11),
Minute: int(dosTime >> 5 & 0x3f),
Second: int(dosTime & 0x1f * 2),
}
}
// Mtime_ns returns the modified time in ns since epoch.
// The resolution is 2s.
func (h *FileHeader) Mtime_ns() int64 {
t := msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
return t.Seconds() * 1e9
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bufio"
"compress/flate"
"encoding/binary"
"hash"
"hash/crc32"
"io"
"os"
)
// TODO(adg): support zip file comments
// TODO(adg): support specifying deflate level
// Writer implements a zip file writer.
type Writer struct {
*countWriter
dir []*header
last *fileWriter
closed bool
}
type header struct {
*FileHeader
offset uint32
}
// NewWriter returns a new Writer writing a zip file to w.
func NewWriter(w io.Writer) *Writer {
return &Writer{countWriter: &countWriter{w: bufio.NewWriter(w)}}
}
// Close finishes writing the zip file by writing the central directory.
// It does not (and can not) close the underlying writer.
func (w *Writer) Close() (err os.Error) {
if w.last != nil && !w.last.closed {
if err = w.last.close(); err != nil {
return
}
w.last = nil
}
if w.closed {
return os.NewError("zip: writer closed twice")
}
w.closed = true
defer recoverError(&err)
// write central directory
start := w.count
for _, h := range w.dir {
write(w, uint32(directoryHeaderSignature))
write(w, h.CreatorVersion)
write(w, h.ReaderVersion)
write(w, h.Flags)
write(w, h.Method)
write(w, h.ModifiedTime)
write(w, h.ModifiedDate)
write(w, h.CRC32)
write(w, h.CompressedSize)
write(w, h.UncompressedSize)
write(w, uint16(len(h.Name)))
write(w, uint16(len(h.Extra)))
write(w, uint16(len(h.Comment)))
write(w, uint16(0)) // disk number start
write(w, uint16(0)) // internal file attributes
write(w, uint32(0)) // external file attributes
write(w, h.offset)
writeBytes(w, []byte(h.Name))
writeBytes(w, h.Extra)
writeBytes(w, []byte(h.Comment))
}
end := w.count
// write end record
write(w, uint32(directoryEndSignature))
write(w, uint16(0)) // disk number
write(w, uint16(0)) // disk number where directory starts
write(w, uint16(len(w.dir))) // number of entries this disk
write(w, uint16(len(w.dir))) // number of entries total
write(w, uint32(end-start)) // size of directory
write(w, uint32(start)) // start of directory
write(w, uint16(0)) // size of comment
return w.w.(*bufio.Writer).Flush()
}
// Create adds a file to the zip file using the provided name.
// It returns a Writer to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) Create(name string) (io.Writer, os.Error) {
header := &FileHeader{
Name: name,
Method: Deflate,
}
return w.CreateHeader(header)
}
// CreateHeader adds a file to the zip file using the provided FileHeader
// for the file metadata.
// It returns a Writer to which the file contents should be written.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, os.Error) {
if w.last != nil && !w.last.closed {
if err := w.last.close(); err != nil {
return nil, err
}
}
fh.Flags |= 0x8 // we will write a data descriptor
fh.CreatorVersion = 0x14
fh.ReaderVersion = 0x14
fw := &fileWriter{
zipw: w,
compCount: &countWriter{w: w},
crc32: crc32.NewIEEE(),
}
switch fh.Method {
case Store:
fw.comp = nopCloser{fw.compCount}
case Deflate:
fw.comp = flate.NewWriter(fw.compCount, 5)
default:
return nil, UnsupportedMethod
}
fw.rawCount = &countWriter{w: fw.comp}
h := &header{
FileHeader: fh,
offset: uint32(w.count),
}
w.dir = append(w.dir, h)
fw.header = h
if err := writeHeader(w, fh); err != nil {
return nil, err
}
w.last = fw
return fw, nil
}
func writeHeader(w io.Writer, h *FileHeader) (err os.Error) {
defer recoverError(&err)
write(w, uint32(fileHeaderSignature))
write(w, h.ReaderVersion)
write(w, h.Flags)
write(w, h.Method)
write(w, h.ModifiedTime)
write(w, h.ModifiedDate)
write(w, h.CRC32)
write(w, h.CompressedSize)
write(w, h.UncompressedSize)
write(w, uint16(len(h.Name)))
write(w, uint16(len(h.Extra)))
writeBytes(w, []byte(h.Name))
writeBytes(w, h.Extra)
return nil
}
type fileWriter struct {
*header
zipw io.Writer
rawCount *countWriter
comp io.WriteCloser
compCount *countWriter
crc32 hash.Hash32
closed bool
}
func (w *fileWriter) Write(p []byte) (int, os.Error) {
if w.closed {
return 0, os.NewError("zip: write to closed file")
}
w.crc32.Write(p)
return w.rawCount.Write(p)
}
func (w *fileWriter) close() (err os.Error) {
if w.closed {
return os.NewError("zip: file closed twice")
}
w.closed = true
if err = w.comp.Close(); err != nil {
return
}
// update FileHeader
fh := w.header.FileHeader
fh.CRC32 = w.crc32.Sum32()
fh.CompressedSize = uint32(w.compCount.count)
fh.UncompressedSize = uint32(w.rawCount.count)
// write data descriptor
defer recoverError(&err)
write(w.zipw, fh.CRC32)
write(w.zipw, fh.CompressedSize)
write(w.zipw, fh.UncompressedSize)
return nil
}
type countWriter struct {
w io.Writer
count int64
}
func (w *countWriter) Write(p []byte) (int, os.Error) {
n, err := w.w.Write(p)
w.count += int64(n)
return n, err
}
type nopCloser struct {
io.Writer
}
func (w nopCloser) Close() os.Error {
return nil
}
func write(w io.Writer, data interface{}) {
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
panic(err)
}
}
func writeBytes(w io.Writer, b []byte) {
n, err := w.Write(b)
if err != nil {
panic(err)
}
if n != len(b) {
panic(io.ErrShortWrite)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"bytes"
"io/ioutil"
"rand"
"testing"
)
// TODO(adg): a more sophisticated test suite
const testString = "Rabbits, guinea pigs, gophers, marsupial rats, and quolls."
func TestWriter(t *testing.T) {
largeData := make([]byte, 1<<17)
for i := range largeData {
largeData[i] = byte(rand.Int())
}
// write a zip file
buf := new(bytes.Buffer)
w := NewWriter(buf)
testCreate(t, w, "foo", []byte(testString), Store)
testCreate(t, w, "bar", largeData, Deflate)
if err := w.Close(); err != nil {
t.Fatal(err)
}
// read it back
r, err := NewReader(sliceReaderAt(buf.Bytes()), int64(buf.Len()))
if err != nil {
t.Fatal(err)
}
testReadFile(t, r.File[0], []byte(testString))
testReadFile(t, r.File[1], largeData)
}
func testCreate(t *testing.T, w *Writer, name string, data []byte, method uint16) {
header := &FileHeader{
Name: name,
Method: method,
}
f, err := w.CreateHeader(header)
if err != nil {
t.Fatal(err)
}
_, err = f.Write(data)
if err != nil {
t.Fatal(err)
}
}
func testReadFile(t *testing.T, f *File, data []byte) {
rc, err := f.Open()
if err != nil {
t.Fatal("opening:", err)
}
b, err := ioutil.ReadAll(rc)
if err != nil {
t.Fatal("reading:", err)
}
err = rc.Close()
if err != nil {
t.Fatal("closing:", err)
}
if !bytes.Equal(b, data) {
t.Errorf("File contents %q, want %q", b, data)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests that involve both reading and writing.
package zip
import (
"bytes"
"fmt"
"os"
"testing"
)
type stringReaderAt string
func (s stringReaderAt) ReadAt(p []byte, off int64) (n int, err os.Error) {
if off >= int64(len(s)) {
return 0, os.EOF
}
n = copy(p, s[off:])
return
}
func TestOver65kFiles(t *testing.T) {
if testing.Short() {
t.Logf("slow test; skipping")
return
}
buf := new(bytes.Buffer)
w := NewWriter(buf)
const nFiles = (1 << 16) + 42
for i := 0; i < nFiles; i++ {
_, err := w.Create(fmt.Sprintf("%d.dat", i))
if err != nil {
t.Fatalf("creating file %d: %v", i, err)
}
}
if err := w.Close(); err != nil {
t.Fatalf("Writer.Close: %v", err)
}
rat := stringReaderAt(buf.String())
zr, err := NewReader(rat, int64(len(rat)))
if err != nil {
t.Fatalf("NewReader: %v", err)
}
if got := len(zr.File); got != nFiles {
t.Fatalf("File contains %d files, want %d", got, nFiles)
}
for i := 0; i < nFiles; i++ {
want := fmt.Sprintf("%d.dat", i)
if zr.File[i].Name != want {
t.Fatalf("File(%d) = %q, want %q", i, zr.File[i].Name, want)
}
}
}
...@@ -20,6 +20,7 @@ package asn1 ...@@ -20,6 +20,7 @@ package asn1
// everything by any means. // everything by any means.
import ( import (
"big"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
...@@ -88,6 +89,27 @@ func parseInt(bytes []byte) (int, os.Error) { ...@@ -88,6 +89,27 @@ func parseInt(bytes []byte) (int, os.Error) {
return int(ret64), nil return int(ret64), nil
} }
var bigOne = big.NewInt(1)
// parseBigInt treats the given bytes as a big-endian, signed integer and returns
// the result.
func parseBigInt(bytes []byte) *big.Int {
ret := new(big.Int)
if len(bytes) > 0 && bytes[0]&0x80 == 0x80 {
// This is a negative number.
notBytes := make([]byte, len(bytes))
for i := range notBytes {
notBytes[i] = ^bytes[i]
}
ret.SetBytes(notBytes)
ret.Add(ret, bigOne)
ret.Neg(ret)
return ret
}
ret.SetBytes(bytes)
return ret
}
// BIT STRING // BIT STRING
// BitString is the structure to use when you want an ASN.1 BIT STRING type. A // BitString is the structure to use when you want an ASN.1 BIT STRING type. A
...@@ -127,7 +149,7 @@ func (b BitString) RightAlign() []byte { ...@@ -127,7 +149,7 @@ func (b BitString) RightAlign() []byte {
return a return a
} }
// parseBitString parses an ASN.1 bit string from the given byte array and returns it. // parseBitString parses an ASN.1 bit string from the given byte slice and returns it.
func parseBitString(bytes []byte) (ret BitString, err os.Error) { func parseBitString(bytes []byte) (ret BitString, err os.Error) {
if len(bytes) == 0 { if len(bytes) == 0 {
err = SyntaxError{"zero length BIT STRING"} err = SyntaxError{"zero length BIT STRING"}
...@@ -164,9 +186,9 @@ func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool { ...@@ -164,9 +186,9 @@ func (oi ObjectIdentifier) Equal(other ObjectIdentifier) bool {
return true return true
} }
// parseObjectIdentifier parses an OBJECT IDENTIFER from the given bytes and // parseObjectIdentifier parses an OBJECT IDENTIFIER from the given bytes and
// returns it. An object identifer is a sequence of variable length integers // returns it. An object identifier is a sequence of variable length integers
// that are assigned in a hierarachy. // that are assigned in a hierarchy.
func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) { func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
if len(bytes) == 0 { if len(bytes) == 0 {
err = SyntaxError{"zero length OBJECT IDENTIFIER"} err = SyntaxError{"zero length OBJECT IDENTIFIER"}
...@@ -198,14 +220,13 @@ func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) { ...@@ -198,14 +220,13 @@ func parseObjectIdentifier(bytes []byte) (s []int, err os.Error) {
// An Enumerated is represented as a plain int. // An Enumerated is represented as a plain int.
type Enumerated int type Enumerated int
// FLAG // FLAG
// A Flag accepts any data and is set to true if present. // A Flag accepts any data and is set to true if present.
type Flag bool type Flag bool
// parseBase128Int parses a base-128 encoded int from the given offset in the // parseBase128Int parses a base-128 encoded int from the given offset in the
// given byte array. It returns the value and the new offset. // given byte slice. It returns the value and the new offset.
func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) { func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err os.Error) {
offset = initOffset offset = initOffset
for shifted := 0; offset < len(bytes); shifted++ { for shifted := 0; offset < len(bytes); shifted++ {
...@@ -237,7 +258,7 @@ func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) { ...@@ -237,7 +258,7 @@ func parseUTCTime(bytes []byte) (ret *time.Time, err os.Error) {
return return
} }
// parseGeneralizedTime parses the GeneralizedTime from the given byte array // parseGeneralizedTime parses the GeneralizedTime from the given byte slice
// and returns the resulting time. // and returns the resulting time.
func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) { func parseGeneralizedTime(bytes []byte) (ret *time.Time, err os.Error) {
return time.Parse("20060102150405Z0700", string(bytes)) return time.Parse("20060102150405Z0700", string(bytes))
...@@ -269,7 +290,7 @@ func isPrintable(b byte) bool { ...@@ -269,7 +290,7 @@ func isPrintable(b byte) bool {
b == ':' || b == ':' ||
b == '=' || b == '=' ||
b == '?' || b == '?' ||
// This is techincally not allowed in a PrintableString. // This is technically not allowed in a PrintableString.
// However, x509 certificates with wildcard strings don't // However, x509 certificates with wildcard strings don't
// always use the correct string type so we permit it. // always use the correct string type so we permit it.
b == '*' b == '*'
...@@ -278,7 +299,7 @@ func isPrintable(b byte) bool { ...@@ -278,7 +299,7 @@ func isPrintable(b byte) bool {
// IA5String // IA5String
// parseIA5String parses a ASN.1 IA5String (ASCII string) from the given // parseIA5String parses a ASN.1 IA5String (ASCII string) from the given
// byte array and returns it. // byte slice and returns it.
func parseIA5String(bytes []byte) (ret string, err os.Error) { func parseIA5String(bytes []byte) (ret string, err os.Error) {
for _, b := range bytes { for _, b := range bytes {
if b >= 0x80 { if b >= 0x80 {
...@@ -293,11 +314,19 @@ func parseIA5String(bytes []byte) (ret string, err os.Error) { ...@@ -293,11 +314,19 @@ func parseIA5String(bytes []byte) (ret string, err os.Error) {
// T61String // T61String
// parseT61String parses a ASN.1 T61String (8-bit clean string) from the given // parseT61String parses a ASN.1 T61String (8-bit clean string) from the given
// byte array and returns it. // byte slice and returns it.
func parseT61String(bytes []byte) (ret string, err os.Error) { func parseT61String(bytes []byte) (ret string, err os.Error) {
return string(bytes), nil return string(bytes), nil
} }
// UTF8String
// parseUTF8String parses a ASN.1 UTF8String (raw UTF-8) from the given byte
// array and returns it.
func parseUTF8String(bytes []byte) (ret string, err os.Error) {
return string(bytes), nil
}
// A RawValue represents an undecoded ASN.1 object. // A RawValue represents an undecoded ASN.1 object.
type RawValue struct { type RawValue struct {
Class, Tag int Class, Tag int
...@@ -314,7 +343,7 @@ type RawContent []byte ...@@ -314,7 +343,7 @@ type RawContent []byte
// Tagging // Tagging
// parseTagAndLength parses an ASN.1 tag and length pair from the given offset // parseTagAndLength parses an ASN.1 tag and length pair from the given offset
// into a byte array. It returns the parsed data and the new offset. SET and // into a byte slice. It returns the parsed data and the new offset. SET and
// SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we // SET OF (tag 17) are mapped to SEQUENCE and SEQUENCE OF (tag 16) since we
// don't distinguish between ordered and unordered objects in this code. // don't distinguish between ordered and unordered objects in this code.
func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) { func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset int, err os.Error) {
...@@ -371,7 +400,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i ...@@ -371,7 +400,7 @@ func parseTagAndLength(bytes []byte, initOffset int) (ret tagAndLength, offset i
} }
// parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse // parseSequenceOf is used for SEQUENCE OF and SET OF values. It tries to parse
// a number of ASN.1 values from the given byte array and returns them as a // a number of ASN.1 values from the given byte slice and returns them as a
// slice of Go values of the given type. // slice of Go values of the given type.
func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) { func parseSequenceOf(bytes []byte, sliceType reflect.Type, elemType reflect.Type) (ret reflect.Value, err os.Error) {
expectedTag, compoundType, ok := getUniversalType(elemType) expectedTag, compoundType, ok := getUniversalType(elemType)
...@@ -425,6 +454,7 @@ var ( ...@@ -425,6 +454,7 @@ var (
timeType = reflect.TypeOf(&time.Time{}) timeType = reflect.TypeOf(&time.Time{})
rawValueType = reflect.TypeOf(RawValue{}) rawValueType = reflect.TypeOf(RawValue{})
rawContentsType = reflect.TypeOf(RawContent(nil)) rawContentsType = reflect.TypeOf(RawContent(nil))
bigIntType = reflect.TypeOf(new(big.Int))
) )
// invalidLength returns true iff offset + length > sliceLength, or if the // invalidLength returns true iff offset + length > sliceLength, or if the
...@@ -433,7 +463,7 @@ func invalidLength(offset, length, sliceLength int) bool { ...@@ -433,7 +463,7 @@ func invalidLength(offset, length, sliceLength int) bool {
return offset+length < offset || offset+length > sliceLength return offset+length < offset || offset+length > sliceLength
} }
// parseField is the main parsing function. Given a byte array and an offset // parseField is the main parsing function. Given a byte slice and an offset
// into the array, it will try to parse a suitable ASN.1 value out and store it // into the array, it will try to parse a suitable ASN.1 value out and store it
// in the given Value. // in the given Value.
func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) { func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParameters) (offset int, err os.Error) {
...@@ -550,16 +580,15 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -550,16 +580,15 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
} }
} }
// Special case for strings: PrintableString and IA5String both map to // Special case for strings: all the ASN.1 string types map to the Go
// the Go type string. getUniversalType returns the tag for // type string. getUniversalType returns the tag for PrintableString
// PrintableString when it sees a string so, if we see an IA5String on // when it sees a string, so if we see a different string type on the
// the wire, we change the universal type to match. // wire, we change the universal type to match.
if universalTag == tagPrintableString && t.tag == tagIA5String { if universalTag == tagPrintableString {
universalTag = tagIA5String switch t.tag {
} case tagIA5String, tagGeneralString, tagT61String, tagUTF8String:
// Likewise for GeneralString universalTag = t.tag
if universalTag == tagPrintableString && t.tag == tagGeneralString { }
universalTag = tagGeneralString
} }
// Special case for time: UTCTime and GeneralizedTime both map to the // Special case for time: UTCTime and GeneralizedTime both map to the
...@@ -639,6 +668,10 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -639,6 +668,10 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
case flagType: case flagType:
v.SetBool(true) v.SetBool(true)
return return
case bigIntType:
parsedInt := parseBigInt(innerBytes)
v.Set(reflect.ValueOf(parsedInt))
return
} }
switch val := v; val.Kind() { switch val := v; val.Kind() {
case reflect.Bool: case reflect.Bool:
...@@ -648,23 +681,21 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -648,23 +681,21 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
} }
err = err1 err = err1
return return
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int32:
switch val.Type().Kind() { parsedInt, err1 := parseInt(innerBytes)
case reflect.Int: if err1 == nil {
parsedInt, err1 := parseInt(innerBytes) val.SetInt(int64(parsedInt))
if err1 == nil {
val.SetInt(int64(parsedInt))
}
err = err1
return
case reflect.Int64:
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
return
} }
err = err1
return
case reflect.Int64:
parsedInt, err1 := parseInt64(innerBytes)
if err1 == nil {
val.SetInt(parsedInt)
}
err = err1
return
// TODO(dfc) Add support for the remaining integer types
case reflect.Struct: case reflect.Struct:
structType := fieldType structType := fieldType
...@@ -680,7 +711,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -680,7 +711,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
if i == 0 && field.Type == rawContentsType { if i == 0 && field.Type == rawContentsType {
continue continue
} }
innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag)) innerOffset, err = parseField(val.Field(i), innerBytes, innerOffset, parseFieldParameters(field.Tag.Get("asn1")))
if err != nil { if err != nil {
return return
} }
...@@ -711,6 +742,8 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -711,6 +742,8 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
v, err = parseIA5String(innerBytes) v, err = parseIA5String(innerBytes)
case tagT61String: case tagT61String:
v, err = parseT61String(innerBytes) v, err = parseT61String(innerBytes)
case tagUTF8String:
v, err = parseUTF8String(innerBytes)
case tagGeneralString: case tagGeneralString:
// GeneralString is specified in ISO-2022/ECMA-35, // GeneralString is specified in ISO-2022/ECMA-35,
// A brief review suggests that it includes structures // A brief review suggests that it includes structures
...@@ -725,7 +758,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam ...@@ -725,7 +758,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
} }
return return
} }
err = StructuralError{"unknown Go type"} err = StructuralError{"unsupported: " + v.Type().String()}
return return
} }
...@@ -752,7 +785,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { ...@@ -752,7 +785,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
// Because Unmarshal uses the reflect package, the structs // Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names. // being written to must use upper case field names.
// //
// An ASN.1 INTEGER can be written to an int or int64. // An ASN.1 INTEGER can be written to an int, int32 or int64.
// If the encoded value does not fit in the Go type, // If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error. // Unmarshal returns a parse error.
// //
......
...@@ -42,6 +42,64 @@ func TestParseInt64(t *testing.T) { ...@@ -42,6 +42,64 @@ func TestParseInt64(t *testing.T) {
} }
} }
type int32Test struct {
in []byte
ok bool
out int32
}
var int32TestData = []int32Test{
{[]byte{0x00}, true, 0},
{[]byte{0x7f}, true, 127},
{[]byte{0x00, 0x80}, true, 128},
{[]byte{0x01, 0x00}, true, 256},
{[]byte{0x80}, true, -128},
{[]byte{0xff, 0x7f}, true, -129},
{[]byte{0xff, 0xff, 0xff, 0xff}, true, -1},
{[]byte{0xff}, true, -1},
{[]byte{0x80, 0x00, 0x00, 0x00}, true, -2147483648},
{[]byte{0x80, 0x00, 0x00, 0x00, 0x00}, false, 0},
}
func TestParseInt32(t *testing.T) {
for i, test := range int32TestData {
ret, err := parseInt(test.in)
if (err == nil) != test.ok {
t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok)
}
if test.ok && int32(ret) != test.out {
t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out)
}
}
}
var bigIntTests = []struct {
in []byte
base10 string
}{
{[]byte{0xff}, "-1"},
{[]byte{0x00}, "0"},
{[]byte{0x01}, "1"},
{[]byte{0x00, 0xff}, "255"},
{[]byte{0xff, 0x00}, "-256"},
{[]byte{0x01, 0x00}, "256"},
}
func TestParseBigInt(t *testing.T) {
for i, test := range bigIntTests {
ret := parseBigInt(test.in)
if ret.String() != test.base10 {
t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
}
fw := newForkableWriter()
marshalBigInt(fw, ret)
result := fw.Bytes()
if !bytes.Equal(result, test.in) {
t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
}
}
}
type bitStringTest struct { type bitStringTest struct {
in []byte in []byte
ok bool ok bool
...@@ -148,10 +206,10 @@ type timeTest struct { ...@@ -148,10 +206,10 @@ type timeTest struct {
} }
var utcTestData = []timeTest{ var utcTestData = []timeTest{
{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, -7 * 60 * 60, ""}}, {"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, -7 * 60 * 60, ""}},
{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 7*60*60 + 30*60, ""}}, {"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 0, 7*60*60 + 30*60, ""}},
{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, "UTC"}}, {"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, 0, "UTC"}},
{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, "UTC"}}, {"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, 0, "UTC"}},
{"a10506234540Z", false, nil}, {"a10506234540Z", false, nil},
{"91a506234540Z", false, nil}, {"91a506234540Z", false, nil},
{"9105a6234540Z", false, nil}, {"9105a6234540Z", false, nil},
...@@ -177,10 +235,10 @@ func TestUTCTime(t *testing.T) { ...@@ -177,10 +235,10 @@ func TestUTCTime(t *testing.T) {
} }
var generalizedTimeTestData = []timeTest{ var generalizedTimeTestData = []timeTest{
{"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, "UTC"}}, {"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 0, "UTC"}},
{"20100102030405", false, nil}, {"20100102030405", false, nil},
{"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 6*60*60 + 7*60, ""}}, {"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, 6*60*60 + 7*60, ""}},
{"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, -6*60*60 - 7*60, ""}}, {"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, -6*60*60 - 7*60, ""}},
} }
func TestGeneralizedTime(t *testing.T) { func TestGeneralizedTime(t *testing.T) {
...@@ -272,11 +330,11 @@ type TestObjectIdentifierStruct struct { ...@@ -272,11 +330,11 @@ type TestObjectIdentifierStruct struct {
} }
type TestContextSpecificTags struct { type TestContextSpecificTags struct {
A int "tag:1" A int `asn1:"tag:1"`
} }
type TestContextSpecificTags2 struct { type TestContextSpecificTags2 struct {
A int "explicit,tag:1" A int `asn1:"explicit,tag:1"`
B int B int
} }
...@@ -326,7 +384,7 @@ type Certificate struct { ...@@ -326,7 +384,7 @@ type Certificate struct {
} }
type TBSCertificate struct { type TBSCertificate struct {
Version int "optional,explicit,default:0,tag:0" Version int `asn1:"optional,explicit,default:0,tag:0"`
SerialNumber RawValue SerialNumber RawValue
SignatureAlgorithm AlgorithmIdentifier SignatureAlgorithm AlgorithmIdentifier
Issuer RDNSequence Issuer RDNSequence
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"strings" "strings"
) )
// ASN.1 objects have metadata preceeding them: // ASN.1 objects have metadata preceding them:
// the tag: the type of the object // the tag: the type of the object
// a flag denoting if this object is compound or not // a flag denoting if this object is compound or not
// the class type: the namespace of the tag // the class type: the namespace of the tag
...@@ -25,6 +25,7 @@ const ( ...@@ -25,6 +25,7 @@ const (
tagOctetString = 4 tagOctetString = 4
tagOID = 6 tagOID = 6
tagEnum = 10 tagEnum = 10
tagUTF8String = 12
tagSequence = 16 tagSequence = 16
tagSet = 17 tagSet = 17
tagPrintableString = 19 tagPrintableString = 19
...@@ -83,7 +84,7 @@ type fieldParameters struct { ...@@ -83,7 +84,7 @@ type fieldParameters struct {
// parseFieldParameters will parse it into a fieldParameters structure, // parseFieldParameters will parse it into a fieldParameters structure,
// ignoring unknown parts of the string. // ignoring unknown parts of the string.
func parseFieldParameters(str string) (ret fieldParameters) { func parseFieldParameters(str string) (ret fieldParameters) {
for _, part := range strings.Split(str, ",", -1) { for _, part := range strings.Split(str, ",") {
switch { switch {
case part == "optional": case part == "optional":
ret.optional = true ret.optional = true
...@@ -132,6 +133,8 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) { ...@@ -132,6 +133,8 @@ func getUniversalType(t reflect.Type) (tagNumber int, isCompound, ok bool) {
return tagUTCTime, false, true return tagUTCTime, false, true
case enumeratedType: case enumeratedType:
return tagEnum, false, true return tagEnum, false, true
case bigIntType:
return tagInteger, false, true
} }
switch t.Kind() { switch t.Kind() {
case reflect.Bool: case reflect.Bool:
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package asn1 package asn1
import ( import (
"big"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -125,6 +126,43 @@ func int64Length(i int64) (numBytes int) { ...@@ -125,6 +126,43 @@ func int64Length(i int64) (numBytes int) {
return return
} }
func marshalBigInt(out *forkableWriter, n *big.Int) (err os.Error) {
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
// form. So we'll subtract 1 and invert. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
nMinus1.Sub(nMinus1, bigOne)
bytes := nMinus1.Bytes()
for i := range bytes {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
err = out.WriteByte(0xff)
if err != nil {
return
}
}
_, err = out.Write(bytes)
} else if n.Sign() == 0 {
// Zero is written as a single 0 zero rather than no bytes.
err = out.WriteByte(0x00)
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with 0x00 in order to stop it
// looking like a negative number.
err = out.WriteByte(0)
if err != nil {
return
}
}
_, err = out.Write(bytes)
}
return
}
func marshalLength(out *forkableWriter, i int) (err os.Error) { func marshalLength(out *forkableWriter, i int) (err os.Error) {
n := lengthLength(i) n := lengthLength(i)
...@@ -334,6 +372,8 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -334,6 +372,8 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
return marshalBitString(out, value.Interface().(BitString)) return marshalBitString(out, value.Interface().(BitString))
case objectIdentifierType: case objectIdentifierType:
return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
case bigIntType:
return marshalBigInt(out, value.Interface().(*big.Int))
} }
switch v := value; v.Kind() { switch v := value; v.Kind() {
...@@ -351,7 +391,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -351,7 +391,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
startingField := 0 startingField := 0
// If the first element of the structure is a non-empty // If the first element of the structure is a non-empty
// RawContents, then we don't bother serialising the rest. // RawContents, then we don't bother serializing the rest.
if t.NumField() > 0 && t.Field(0).Type == rawContentsType { if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
s := v.Field(0) s := v.Field(0)
if s.Len() > 0 { if s.Len() > 0 {
...@@ -361,7 +401,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -361,7 +401,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
} }
/* The RawContents will contain the tag and /* The RawContents will contain the tag and
* length fields but we'll also be writing * length fields but we'll also be writing
* those outselves, so we strip them out of * those ourselves, so we strip them out of
* bytes */ * bytes */
_, err = out.Write(stripTagAndLength(bytes)) _, err = out.Write(stripTagAndLength(bytes))
return return
...@@ -373,7 +413,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter ...@@ -373,7 +413,7 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
for i := startingField; i < t.NumField(); i++ { for i := startingField; i < t.NumField(); i++ {
var pre *forkableWriter var pre *forkableWriter
pre, out = out.fork() pre, out = out.fork()
err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag)) err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
if err != nil { if err != nil {
return return
} }
...@@ -418,6 +458,10 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -418,6 +458,10 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return marshalField(out, v.Elem(), params) return marshalField(out, v.Elem(), params)
} }
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return
}
if v.Type() == rawValueType { if v.Type() == rawValueType {
rv := v.Interface().(RawValue) rv := v.Interface().(RawValue)
err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}) err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
...@@ -428,10 +472,6 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) ...@@ -428,10 +472,6 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
return return
} }
if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
return
}
tag, isCompound, ok := getUniversalType(v.Type()) tag, isCompound, ok := getUniversalType(v.Type())
if !ok { if !ok {
err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
......
...@@ -30,19 +30,23 @@ type rawContentsStruct struct { ...@@ -30,19 +30,23 @@ type rawContentsStruct struct {
} }
type implicitTagTest struct { type implicitTagTest struct {
A int "implicit,tag:5" A int `asn1:"implicit,tag:5"`
} }
type explicitTagTest struct { type explicitTagTest struct {
A int "explicit,tag:5" A int `asn1:"explicit,tag:5"`
} }
type ia5StringTest struct { type ia5StringTest struct {
A string "ia5" A string `asn1:"ia5"`
} }
type printableStringTest struct { type printableStringTest struct {
A string "printable" A string `asn1:"printable"`
}
type optionalRawValueTest struct {
A RawValue `asn1:"optional"`
} }
type testSET []int type testSET []int
...@@ -102,6 +106,7 @@ var marshalTests = []marshalTest{ ...@@ -102,6 +106,7 @@ var marshalTests = []marshalTest{
"7878787878787878787878787878787878787878787878787878787878787878", "7878787878787878787878787878787878787878787878787878787878787878",
}, },
{ia5StringTest{"test"}, "3006160474657374"}, {ia5StringTest{"test"}, "3006160474657374"},
{optionalRawValueTest{}, "3000"},
{printableStringTest{"test"}, "3006130474657374"}, {printableStringTest{"test"}, "3006130474657374"},
{printableStringTest{"test*"}, "30071305746573742a"}, {printableStringTest{"test*"}, "30071305746573742a"},
{rawContentsStruct{nil, 64}, "3003020140"}, {rawContentsStruct{nil, 64}, "3003020140"},
......
...@@ -27,7 +27,6 @@ const ( ...@@ -27,7 +27,6 @@ const (
_M2 = _B2 - 1 // half digit mask _M2 = _B2 - 1 // half digit mask
) )
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Elementary operations on words // Elementary operations on words
// //
...@@ -43,7 +42,6 @@ func addWW_g(x, y, c Word) (z1, z0 Word) { ...@@ -43,7 +42,6 @@ func addWW_g(x, y, c Word) (z1, z0 Word) {
return return
} }
// z1<<_W + z0 = x-y-c, with c == 0 or 1 // z1<<_W + z0 = x-y-c, with c == 0 or 1
func subWW_g(x, y, c Word) (z1, z0 Word) { func subWW_g(x, y, c Word) (z1, z0 Word) {
yc := y + c yc := y + c
...@@ -54,7 +52,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) { ...@@ -54,7 +52,6 @@ func subWW_g(x, y, c Word) (z1, z0 Word) {
return return
} }
// z1<<_W + z0 = x*y // z1<<_W + z0 = x*y
func mulWW(x, y Word) (z1, z0 Word) { return mulWW_g(x, y) } func mulWW(x, y Word) (z1, z0 Word) { return mulWW_g(x, y) }
// Adapted from Warren, Hacker's Delight, p. 132. // Adapted from Warren, Hacker's Delight, p. 132.
...@@ -73,7 +70,6 @@ func mulWW_g(x, y Word) (z1, z0 Word) { ...@@ -73,7 +70,6 @@ func mulWW_g(x, y Word) (z1, z0 Word) {
return return
} }
// z1<<_W + z0 = x*y + c // z1<<_W + z0 = x*y + c
func mulAddWWW_g(x, y, c Word) (z1, z0 Word) { func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
z1, zz0 := mulWW(x, y) z1, zz0 := mulWW(x, y)
...@@ -83,7 +79,6 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) { ...@@ -83,7 +79,6 @@ func mulAddWWW_g(x, y, c Word) (z1, z0 Word) {
return return
} }
// Length of x in bits. // Length of x in bits.
func bitLen(x Word) (n int) { func bitLen(x Word) (n int) {
for ; x >= 0x100; x >>= 8 { for ; x >= 0x100; x >>= 8 {
...@@ -95,7 +90,6 @@ func bitLen(x Word) (n int) { ...@@ -95,7 +90,6 @@ func bitLen(x Word) (n int) {
return return
} }
// log2 computes the integer binary logarithm of x. // log2 computes the integer binary logarithm of x.
// The result is the integer n for which 2^n <= x < 2^(n+1). // The result is the integer n for which 2^n <= x < 2^(n+1).
// If x == 0, the result is -1. // If x == 0, the result is -1.
...@@ -103,13 +97,11 @@ func log2(x Word) int { ...@@ -103,13 +97,11 @@ func log2(x Word) int {
return bitLen(x) - 1 return bitLen(x) - 1
} }
// Number of leading zeros in x. // Number of leading zeros in x.
func leadingZeros(x Word) uint { func leadingZeros(x Word) uint {
return uint(_W - bitLen(x)) return uint(_W - bitLen(x))
} }
// q = (u1<<_W + u0 - r)/y // q = (u1<<_W + u0 - r)/y
func divWW(x1, x0, y Word) (q, r Word) { return divWW_g(x1, x0, y) } func divWW(x1, x0, y Word) (q, r Word) { return divWW_g(x1, x0, y) }
// Adapted from Warren, Hacker's Delight, p. 152. // Adapted from Warren, Hacker's Delight, p. 152.
...@@ -155,7 +147,6 @@ again2: ...@@ -155,7 +147,6 @@ again2:
return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s return q1*_B2 + q0, (un21*_B2 + un0 - q0*v) >> s
} }
func addVV(z, x, y []Word) (c Word) { return addVV_g(z, x, y) } func addVV(z, x, y []Word) (c Word) { return addVV_g(z, x, y) }
func addVV_g(z, x, y []Word) (c Word) { func addVV_g(z, x, y []Word) (c Word) {
for i := range z { for i := range z {
...@@ -164,7 +155,6 @@ func addVV_g(z, x, y []Word) (c Word) { ...@@ -164,7 +155,6 @@ func addVV_g(z, x, y []Word) (c Word) {
return return
} }
func subVV(z, x, y []Word) (c Word) { return subVV_g(z, x, y) } func subVV(z, x, y []Word) (c Word) { return subVV_g(z, x, y) }
func subVV_g(z, x, y []Word) (c Word) { func subVV_g(z, x, y []Word) (c Word) {
for i := range z { for i := range z {
...@@ -173,7 +163,6 @@ func subVV_g(z, x, y []Word) (c Word) { ...@@ -173,7 +163,6 @@ func subVV_g(z, x, y []Word) (c Word) {
return return
} }
func addVW(z, x []Word, y Word) (c Word) { return addVW_g(z, x, y) } func addVW(z, x []Word, y Word) (c Word) { return addVW_g(z, x, y) }
func addVW_g(z, x []Word, y Word) (c Word) { func addVW_g(z, x []Word, y Word) (c Word) {
c = y c = y
...@@ -183,7 +172,6 @@ func addVW_g(z, x []Word, y Word) (c Word) { ...@@ -183,7 +172,6 @@ func addVW_g(z, x []Word, y Word) (c Word) {
return return
} }
func subVW(z, x []Word, y Word) (c Word) { return subVW_g(z, x, y) } func subVW(z, x []Word, y Word) (c Word) { return subVW_g(z, x, y) }
func subVW_g(z, x []Word, y Word) (c Word) { func subVW_g(z, x []Word, y Word) (c Word) {
c = y c = y
...@@ -193,9 +181,8 @@ func subVW_g(z, x []Word, y Word) (c Word) { ...@@ -193,9 +181,8 @@ func subVW_g(z, x []Word, y Word) (c Word) {
return return
} }
func shlVU(z, x []Word, s uint) (c Word) { return shlVU_g(z, x, s) }
func shlVW(z, x []Word, s Word) (c Word) { return shlVW_g(z, x, s) } func shlVU_g(z, x []Word, s uint) (c Word) {
func shlVW_g(z, x []Word, s Word) (c Word) {
if n := len(z); n > 0 { if n := len(z); n > 0 {
ŝ := _W - s ŝ := _W - s
w1 := x[n-1] w1 := x[n-1]
...@@ -210,9 +197,8 @@ func shlVW_g(z, x []Word, s Word) (c Word) { ...@@ -210,9 +197,8 @@ func shlVW_g(z, x []Word, s Word) (c Word) {
return return
} }
func shrVU(z, x []Word, s uint) (c Word) { return shrVU_g(z, x, s) }
func shrVW(z, x []Word, s Word) (c Word) { return shrVW_g(z, x, s) } func shrVU_g(z, x []Word, s uint) (c Word) {
func shrVW_g(z, x []Word, s Word) (c Word) {
if n := len(z); n > 0 { if n := len(z); n > 0 {
ŝ := _W - s ŝ := _W - s
w1 := x[0] w1 := x[0]
...@@ -227,7 +213,6 @@ func shrVW_g(z, x []Word, s Word) (c Word) { ...@@ -227,7 +213,6 @@ func shrVW_g(z, x []Word, s Word) (c Word) {
return return
} }
func mulAddVWW(z, x []Word, y, r Word) (c Word) { return mulAddVWW_g(z, x, y, r) } func mulAddVWW(z, x []Word, y, r Word) (c Word) { return mulAddVWW_g(z, x, y, r) }
func mulAddVWW_g(z, x []Word, y, r Word) (c Word) { func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
c = r c = r
...@@ -237,7 +222,6 @@ func mulAddVWW_g(z, x []Word, y, r Word) (c Word) { ...@@ -237,7 +222,6 @@ func mulAddVWW_g(z, x []Word, y, r Word) (c Word) {
return return
} }
func addMulVVW(z, x []Word, y Word) (c Word) { return addMulVVW_g(z, x, y) } func addMulVVW(z, x []Word, y Word) (c Word) { return addMulVVW_g(z, x, y) }
func addMulVVW_g(z, x []Word, y Word) (c Word) { func addMulVVW_g(z, x []Word, y Word) (c Word) {
for i := range z { for i := range z {
...@@ -248,7 +232,6 @@ func addMulVVW_g(z, x []Word, y Word) (c Word) { ...@@ -248,7 +232,6 @@ func addMulVVW_g(z, x []Word, y Word) (c Word) {
return return
} }
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { return divWVW_g(z, xn, x, y) } func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) { return divWVW_g(z, xn, x, y) }
func divWVW_g(z []Word, xn Word, x []Word, y Word) (r Word) { func divWVW_g(z []Word, xn Word, x []Word, y Word) (r Word) {
r = xn r = xn
......
...@@ -11,8 +11,8 @@ func addVV(z, x, y []Word) (c Word) ...@@ -11,8 +11,8 @@ func addVV(z, x, y []Word) (c Word)
func subVV(z, x, y []Word) (c Word) func subVV(z, x, y []Word) (c Word)
func addVW(z, x []Word, y Word) (c Word) func addVW(z, x []Word, y Word) (c Word)
func subVW(z, x []Word, y Word) (c Word) func subVW(z, x []Word, y Word) (c Word)
func shlVW(z, x []Word, s Word) (c Word) func shlVU(z, x []Word, s uint) (c Word)
func shrVW(z, x []Word, s Word) (c Word) func shrVU(z, x []Word, s uint) (c Word)
func mulAddVWW(z, x []Word, y, r Word) (c Word) func mulAddVWW(z, x []Word, y, r Word) (c Word)
func addMulVVW(z, x []Word, y Word) (c Word) func addMulVVW(z, x []Word, y Word) (c Word)
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) func divWVW(z []Word, xn Word, x []Word, y Word) (r Word)
...@@ -6,7 +6,6 @@ package big ...@@ -6,7 +6,6 @@ package big
import "testing" import "testing"
type funWW func(x, y, c Word) (z1, z0 Word) type funWW func(x, y, c Word) (z1, z0 Word)
type argWW struct { type argWW struct {
x, y, c, z1, z0 Word x, y, c, z1, z0 Word
...@@ -26,7 +25,6 @@ var sumWW = []argWW{ ...@@ -26,7 +25,6 @@ var sumWW = []argWW{
{_M, _M, 1, 1, _M}, {_M, _M, 1, 1, _M},
} }
func testFunWW(t *testing.T, msg string, f funWW, a argWW) { func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
z1, z0 := f(a.x, a.y, a.c) z1, z0 := f(a.x, a.y, a.c)
if z1 != a.z1 || z0 != a.z0 { if z1 != a.z1 || z0 != a.z0 {
...@@ -34,7 +32,6 @@ func testFunWW(t *testing.T, msg string, f funWW, a argWW) { ...@@ -34,7 +32,6 @@ func testFunWW(t *testing.T, msg string, f funWW, a argWW) {
} }
} }
func TestFunWW(t *testing.T) { func TestFunWW(t *testing.T) {
for _, a := range sumWW { for _, a := range sumWW {
arg := a arg := a
...@@ -51,7 +48,6 @@ func TestFunWW(t *testing.T) { ...@@ -51,7 +48,6 @@ func TestFunWW(t *testing.T) {
} }
} }
type funVV func(z, x, y []Word) (c Word) type funVV func(z, x, y []Word) (c Word)
type argVV struct { type argVV struct {
z, x, y nat z, x, y nat
...@@ -70,7 +66,6 @@ var sumVV = []argVV{ ...@@ -70,7 +66,6 @@ var sumVV = []argVV{
{nat{0, 0, 0, 0}, nat{_M, 0, _M, 0}, nat{1, _M, 0, _M}, 1}, {nat{0, 0, 0, 0}, nat{_M, 0, _M, 0}, nat{1, _M, 0, _M}, 1},
} }
func testFunVV(t *testing.T, msg string, f funVV, a argVV) { func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
z := make(nat, len(a.z)) z := make(nat, len(a.z))
c := f(z, a.x, a.y) c := f(z, a.x, a.y)
...@@ -85,7 +80,6 @@ func testFunVV(t *testing.T, msg string, f funVV, a argVV) { ...@@ -85,7 +80,6 @@ func testFunVV(t *testing.T, msg string, f funVV, a argVV) {
} }
} }
func TestFunVV(t *testing.T) { func TestFunVV(t *testing.T) {
for _, a := range sumVV { for _, a := range sumVV {
arg := a arg := a
...@@ -106,7 +100,6 @@ func TestFunVV(t *testing.T) { ...@@ -106,7 +100,6 @@ func TestFunVV(t *testing.T) {
} }
} }
type funVW func(z, x []Word, y Word) (c Word) type funVW func(z, x []Word, y Word) (c Word)
type argVW struct { type argVW struct {
z, x nat z, x nat
...@@ -169,7 +162,6 @@ var rshVW = []argVW{ ...@@ -169,7 +162,6 @@ var rshVW = []argVW{
{nat{_M, _M, _M >> 20}, nat{_M, _M, _M}, 20, _M << (_W - 20) & _M}, {nat{_M, _M, _M >> 20}, nat{_M, _M, _M}, 20, _M << (_W - 20) & _M},
} }
func testFunVW(t *testing.T, msg string, f funVW, a argVW) { func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
z := make(nat, len(a.z)) z := make(nat, len(a.z))
c := f(z, a.x, a.y) c := f(z, a.x, a.y)
...@@ -184,6 +176,11 @@ func testFunVW(t *testing.T, msg string, f funVW, a argVW) { ...@@ -184,6 +176,11 @@ func testFunVW(t *testing.T, msg string, f funVW, a argVW) {
} }
} }
func makeFunVW(f func(z, x []Word, s uint) (c Word)) funVW {
return func(z, x []Word, s Word) (c Word) {
return f(z, x, uint(s))
}
}
func TestFunVW(t *testing.T) { func TestFunVW(t *testing.T) {
for _, a := range sumVW { for _, a := range sumVW {
...@@ -196,20 +193,23 @@ func TestFunVW(t *testing.T) { ...@@ -196,20 +193,23 @@ func TestFunVW(t *testing.T) {
testFunVW(t, "subVW", subVW, arg) testFunVW(t, "subVW", subVW, arg)
} }
shlVW_g := makeFunVW(shlVU_g)
shlVW := makeFunVW(shlVU)
for _, a := range lshVW { for _, a := range lshVW {
arg := a arg := a
testFunVW(t, "shlVW_g", shlVW_g, arg) testFunVW(t, "shlVU_g", shlVW_g, arg)
testFunVW(t, "shlVW", shlVW, arg) testFunVW(t, "shlVU", shlVW, arg)
} }
shrVW_g := makeFunVW(shrVU_g)
shrVW := makeFunVW(shrVU)
for _, a := range rshVW { for _, a := range rshVW {
arg := a arg := a
testFunVW(t, "shrVW_g", shrVW_g, arg) testFunVW(t, "shrVU_g", shrVW_g, arg)
testFunVW(t, "shrVW", shrVW, arg) testFunVW(t, "shrVU", shrVW, arg)
} }
} }
type funVWW func(z, x []Word, y, r Word) (c Word) type funVWW func(z, x []Word, y, r Word) (c Word)
type argVWW struct { type argVWW struct {
z, x nat z, x nat
...@@ -243,7 +243,6 @@ var prodVWW = []argVWW{ ...@@ -243,7 +243,6 @@ var prodVWW = []argVWW{
{nat{_M<<7&_M + 1<<6, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 1 << 6, _M >> (_W - 7)}, {nat{_M<<7&_M + 1<<6, _M, _M, _M}, nat{_M, _M, _M, _M}, 1 << 7, 1 << 6, _M >> (_W - 7)},
} }
func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) { func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
z := make(nat, len(a.z)) z := make(nat, len(a.z))
c := f(z, a.x, a.y, a.r) c := f(z, a.x, a.y, a.r)
...@@ -258,7 +257,6 @@ func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) { ...@@ -258,7 +257,6 @@ func testFunVWW(t *testing.T, msg string, f funVWW, a argVWW) {
} }
} }
// TODO(gri) mulAddVWW and divWVW are symmetric operations but // TODO(gri) mulAddVWW and divWVW are symmetric operations but
// their signature is not symmetric. Try to unify. // their signature is not symmetric. Try to unify.
...@@ -285,7 +283,6 @@ func testFunWVW(t *testing.T, msg string, f funWVW, a argWVW) { ...@@ -285,7 +283,6 @@ func testFunWVW(t *testing.T, msg string, f funWVW, a argWVW) {
} }
} }
func TestFunVWW(t *testing.T) { func TestFunVWW(t *testing.T) {
for _, a := range prodVWW { for _, a := range prodVWW {
arg := a arg := a
...@@ -300,7 +297,6 @@ func TestFunVWW(t *testing.T) { ...@@ -300,7 +297,6 @@ func TestFunVWW(t *testing.T) {
} }
} }
var mulWWTests = []struct { var mulWWTests = []struct {
x, y Word x, y Word
q, r Word q, r Word
...@@ -309,7 +305,6 @@ var mulWWTests = []struct { ...@@ -309,7 +305,6 @@ var mulWWTests = []struct {
// 32 bit only: {0xc47dfa8c, 50911, 0x98a4, 0x998587f4}, // 32 bit only: {0xc47dfa8c, 50911, 0x98a4, 0x998587f4},
} }
func TestMulWW(t *testing.T) { func TestMulWW(t *testing.T) {
for i, test := range mulWWTests { for i, test := range mulWWTests {
q, r := mulWW_g(test.x, test.y) q, r := mulWW_g(test.x, test.y)
...@@ -319,7 +314,6 @@ func TestMulWW(t *testing.T) { ...@@ -319,7 +314,6 @@ func TestMulWW(t *testing.T) {
} }
} }
var mulAddWWWTests = []struct { var mulAddWWWTests = []struct {
x, y, c Word x, y, c Word
q, r Word q, r Word
...@@ -331,7 +325,6 @@ var mulAddWWWTests = []struct { ...@@ -331,7 +325,6 @@ var mulAddWWWTests = []struct {
{_M, _M, _M, _M, 0}, {_M, _M, _M, _M, 0},
} }
func TestMulAddWWW(t *testing.T) { func TestMulAddWWW(t *testing.T) {
for i, test := range mulAddWWWTests { for i, test := range mulAddWWWTests {
q, r := mulAddWWW_g(test.x, test.y, test.c) q, r := mulAddWWW_g(test.x, test.y, test.c)
......
...@@ -19,10 +19,8 @@ import ( ...@@ -19,10 +19,8 @@ import (
"time" "time"
) )
var calibrate = flag.Bool("calibrate", false, "run calibration test") var calibrate = flag.Bool("calibrate", false, "run calibration test")
// measure returns the time to run f // measure returns the time to run f
func measure(f func()) int64 { func measure(f func()) int64 {
const N = 100 const N = 100
...@@ -34,7 +32,6 @@ func measure(f func()) int64 { ...@@ -34,7 +32,6 @@ func measure(f func()) int64 {
return (stop - start) / N return (stop - start) / N
} }
func computeThresholds() { func computeThresholds() {
fmt.Printf("Multiplication times for varying Karatsuba thresholds\n") fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
fmt.Printf("(run repeatedly for good results)\n") fmt.Printf("(run repeatedly for good results)\n")
...@@ -84,7 +81,6 @@ func computeThresholds() { ...@@ -84,7 +81,6 @@ func computeThresholds() {
} }
} }
func TestCalibrate(t *testing.T) { func TestCalibrate(t *testing.T) {
if *calibrate { if *calibrate {
computeThresholds() computeThresholds()
......
...@@ -13,13 +13,11 @@ import ( ...@@ -13,13 +13,11 @@ import (
"testing" "testing"
) )
type matrix struct { type matrix struct {
n, m int n, m int
a []*Rat a []*Rat
} }
func (a *matrix) at(i, j int) *Rat { func (a *matrix) at(i, j int) *Rat {
if !(0 <= i && i < a.n && 0 <= j && j < a.m) { if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
panic("index out of range") panic("index out of range")
...@@ -27,7 +25,6 @@ func (a *matrix) at(i, j int) *Rat { ...@@ -27,7 +25,6 @@ func (a *matrix) at(i, j int) *Rat {
return a.a[i*a.m+j] return a.a[i*a.m+j]
} }
func (a *matrix) set(i, j int, x *Rat) { func (a *matrix) set(i, j int, x *Rat) {
if !(0 <= i && i < a.n && 0 <= j && j < a.m) { if !(0 <= i && i < a.n && 0 <= j && j < a.m) {
panic("index out of range") panic("index out of range")
...@@ -35,7 +32,6 @@ func (a *matrix) set(i, j int, x *Rat) { ...@@ -35,7 +32,6 @@ func (a *matrix) set(i, j int, x *Rat) {
a.a[i*a.m+j] = x a.a[i*a.m+j] = x
} }
func newMatrix(n, m int) *matrix { func newMatrix(n, m int) *matrix {
if !(0 <= n && 0 <= m) { if !(0 <= n && 0 <= m) {
panic("illegal matrix") panic("illegal matrix")
...@@ -47,7 +43,6 @@ func newMatrix(n, m int) *matrix { ...@@ -47,7 +43,6 @@ func newMatrix(n, m int) *matrix {
return a return a
} }
func newUnit(n int) *matrix { func newUnit(n int) *matrix {
a := newMatrix(n, n) a := newMatrix(n, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
...@@ -62,7 +57,6 @@ func newUnit(n int) *matrix { ...@@ -62,7 +57,6 @@ func newUnit(n int) *matrix {
return a return a
} }
func newHilbert(n int) *matrix { func newHilbert(n int) *matrix {
a := newMatrix(n, n) a := newMatrix(n, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
...@@ -73,7 +67,6 @@ func newHilbert(n int) *matrix { ...@@ -73,7 +67,6 @@ func newHilbert(n int) *matrix {
return a return a
} }
func newInverseHilbert(n int) *matrix { func newInverseHilbert(n int) *matrix {
a := newMatrix(n, n) a := newMatrix(n, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
...@@ -98,7 +91,6 @@ func newInverseHilbert(n int) *matrix { ...@@ -98,7 +91,6 @@ func newInverseHilbert(n int) *matrix {
return a return a
} }
func (a *matrix) mul(b *matrix) *matrix { func (a *matrix) mul(b *matrix) *matrix {
if a.m != b.n { if a.m != b.n {
panic("illegal matrix multiply") panic("illegal matrix multiply")
...@@ -116,7 +108,6 @@ func (a *matrix) mul(b *matrix) *matrix { ...@@ -116,7 +108,6 @@ func (a *matrix) mul(b *matrix) *matrix {
return c return c
} }
func (a *matrix) eql(b *matrix) bool { func (a *matrix) eql(b *matrix) bool {
if a.n != b.n || a.m != b.m { if a.n != b.n || a.m != b.m {
return false return false
...@@ -131,7 +122,6 @@ func (a *matrix) eql(b *matrix) bool { ...@@ -131,7 +122,6 @@ func (a *matrix) eql(b *matrix) bool {
return true return true
} }
func (a *matrix) String() string { func (a *matrix) String() string {
s := "" s := ""
for i := 0; i < a.n; i++ { for i := 0; i < a.n; i++ {
...@@ -143,7 +133,6 @@ func (a *matrix) String() string { ...@@ -143,7 +133,6 @@ func (a *matrix) String() string {
return s return s
} }
func doHilbert(t *testing.T, n int) { func doHilbert(t *testing.T, n int) {
a := newHilbert(n) a := newHilbert(n)
b := newInverseHilbert(n) b := newInverseHilbert(n)
...@@ -160,12 +149,10 @@ func doHilbert(t *testing.T, n int) { ...@@ -160,12 +149,10 @@ func doHilbert(t *testing.T, n int) {
} }
} }
func TestHilbert(t *testing.T) { func TestHilbert(t *testing.T) {
doHilbert(t, 10) doHilbert(t, 10)
} }
func BenchmarkHilbert(b *testing.B) { func BenchmarkHilbert(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
doHilbert(nil, 10) doHilbert(nil, 10)
......
...@@ -6,7 +6,12 @@ ...@@ -6,7 +6,12 @@
package big package big
import "strings" import (
"encoding/binary"
"fmt"
"os"
"strings"
)
// A Rat represents a quotient a/b of arbitrary precision. The zero value for // A Rat represents a quotient a/b of arbitrary precision. The zero value for
// a Rat, 0/0, is not a legal Rat. // a Rat, 0/0, is not a legal Rat.
...@@ -15,13 +20,11 @@ type Rat struct { ...@@ -15,13 +20,11 @@ type Rat struct {
b nat b nat
} }
// NewRat creates a new Rat with numerator a and denominator b. // NewRat creates a new Rat with numerator a and denominator b.
func NewRat(a, b int64) *Rat { func NewRat(a, b int64) *Rat {
return new(Rat).SetFrac64(a, b) return new(Rat).SetFrac64(a, b)
} }
// SetFrac sets z to a/b and returns z. // SetFrac sets z to a/b and returns z.
func (z *Rat) SetFrac(a, b *Int) *Rat { func (z *Rat) SetFrac(a, b *Int) *Rat {
z.a.Set(a) z.a.Set(a)
...@@ -30,7 +33,6 @@ func (z *Rat) SetFrac(a, b *Int) *Rat { ...@@ -30,7 +33,6 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
return z.norm() return z.norm()
} }
// SetFrac64 sets z to a/b and returns z. // SetFrac64 sets z to a/b and returns z.
func (z *Rat) SetFrac64(a, b int64) *Rat { func (z *Rat) SetFrac64(a, b int64) *Rat {
z.a.SetInt64(a) z.a.SetInt64(a)
...@@ -42,7 +44,6 @@ func (z *Rat) SetFrac64(a, b int64) *Rat { ...@@ -42,7 +44,6 @@ func (z *Rat) SetFrac64(a, b int64) *Rat {
return z.norm() return z.norm()
} }
// SetInt sets z to x (by making a copy of x) and returns z. // SetInt sets z to x (by making a copy of x) and returns z.
func (z *Rat) SetInt(x *Int) *Rat { func (z *Rat) SetInt(x *Int) *Rat {
z.a.Set(x) z.a.Set(x)
...@@ -50,7 +51,6 @@ func (z *Rat) SetInt(x *Int) *Rat { ...@@ -50,7 +51,6 @@ func (z *Rat) SetInt(x *Int) *Rat {
return z return z
} }
// SetInt64 sets z to x and returns z. // SetInt64 sets z to x and returns z.
func (z *Rat) SetInt64(x int64) *Rat { func (z *Rat) SetInt64(x int64) *Rat {
z.a.SetInt64(x) z.a.SetInt64(x)
...@@ -58,7 +58,6 @@ func (z *Rat) SetInt64(x int64) *Rat { ...@@ -58,7 +58,6 @@ func (z *Rat) SetInt64(x int64) *Rat {
return z return z
} }
// Sign returns: // Sign returns:
// //
// -1 if x < 0 // -1 if x < 0
...@@ -69,13 +68,11 @@ func (x *Rat) Sign() int { ...@@ -69,13 +68,11 @@ func (x *Rat) Sign() int {
return x.a.Sign() return x.a.Sign()
} }
// IsInt returns true if the denominator of x is 1. // IsInt returns true if the denominator of x is 1.
func (x *Rat) IsInt() bool { func (x *Rat) IsInt() bool {
return len(x.b) == 1 && x.b[0] == 1 return len(x.b) == 1 && x.b[0] == 1
} }
// Num returns the numerator of z; it may be <= 0. // Num returns the numerator of z; it may be <= 0.
// The result is a reference to z's numerator; it // The result is a reference to z's numerator; it
// may change if a new value is assigned to z. // may change if a new value is assigned to z.
...@@ -83,15 +80,13 @@ func (z *Rat) Num() *Int { ...@@ -83,15 +80,13 @@ func (z *Rat) Num() *Int {
return &z.a return &z.a
} }
// Denom returns the denominator of z; it is always > 0.
// Demom returns the denominator of z; it is always > 0.
// The result is a reference to z's denominator; it // The result is a reference to z's denominator; it
// may change if a new value is assigned to z. // may change if a new value is assigned to z.
func (z *Rat) Denom() *Int { func (z *Rat) Denom() *Int {
return &Int{false, z.b} return &Int{false, z.b}
} }
func gcd(x, y nat) nat { func gcd(x, y nat) nat {
// Euclidean algorithm. // Euclidean algorithm.
var a, b nat var a, b nat
...@@ -106,7 +101,6 @@ func gcd(x, y nat) nat { ...@@ -106,7 +101,6 @@ func gcd(x, y nat) nat {
return a return a
} }
func (z *Rat) norm() *Rat { func (z *Rat) norm() *Rat {
f := gcd(z.a.abs, z.b) f := gcd(z.a.abs, z.b)
if len(z.a.abs) == 0 { if len(z.a.abs) == 0 {
...@@ -122,7 +116,6 @@ func (z *Rat) norm() *Rat { ...@@ -122,7 +116,6 @@ func (z *Rat) norm() *Rat {
return z return z
} }
func mulNat(x *Int, y nat) *Int { func mulNat(x *Int, y nat) *Int {
var z Int var z Int
z.abs = z.abs.mul(x.abs, y) z.abs = z.abs.mul(x.abs, y)
...@@ -130,7 +123,6 @@ func mulNat(x *Int, y nat) *Int { ...@@ -130,7 +123,6 @@ func mulNat(x *Int, y nat) *Int {
return &z return &z
} }
// Cmp compares x and y and returns: // Cmp compares x and y and returns:
// //
// -1 if x < y // -1 if x < y
...@@ -141,7 +133,6 @@ func (x *Rat) Cmp(y *Rat) (r int) { ...@@ -141,7 +133,6 @@ func (x *Rat) Cmp(y *Rat) (r int) {
return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b)) return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b))
} }
// Abs sets z to |x| (the absolute value of x) and returns z. // Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Rat) Abs(x *Rat) *Rat { func (z *Rat) Abs(x *Rat) *Rat {
z.a.Abs(&x.a) z.a.Abs(&x.a)
...@@ -149,7 +140,6 @@ func (z *Rat) Abs(x *Rat) *Rat { ...@@ -149,7 +140,6 @@ func (z *Rat) Abs(x *Rat) *Rat {
return z return z
} }
// Add sets z to the sum x+y and returns z. // Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat { func (z *Rat) Add(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b) a1 := mulNat(&x.a, y.b)
...@@ -159,7 +149,6 @@ func (z *Rat) Add(x, y *Rat) *Rat { ...@@ -159,7 +149,6 @@ func (z *Rat) Add(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Sub sets z to the difference x-y and returns z. // Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat { func (z *Rat) Sub(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b) a1 := mulNat(&x.a, y.b)
...@@ -169,7 +158,6 @@ func (z *Rat) Sub(x, y *Rat) *Rat { ...@@ -169,7 +158,6 @@ func (z *Rat) Sub(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Mul sets z to the product x*y and returns z. // Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat { func (z *Rat) Mul(x, y *Rat) *Rat {
z.a.Mul(&x.a, &y.a) z.a.Mul(&x.a, &y.a)
...@@ -177,7 +165,6 @@ func (z *Rat) Mul(x, y *Rat) *Rat { ...@@ -177,7 +165,6 @@ func (z *Rat) Mul(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Quo sets z to the quotient x/y and returns z. // Quo sets z to the quotient x/y and returns z.
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
func (z *Rat) Quo(x, y *Rat) *Rat { func (z *Rat) Quo(x, y *Rat) *Rat {
...@@ -192,7 +179,6 @@ func (z *Rat) Quo(x, y *Rat) *Rat { ...@@ -192,7 +179,6 @@ func (z *Rat) Quo(x, y *Rat) *Rat {
return z.norm() return z.norm()
} }
// Neg sets z to -x (by making a copy of x if necessary) and returns z. // Neg sets z to -x (by making a copy of x if necessary) and returns z.
func (z *Rat) Neg(x *Rat) *Rat { func (z *Rat) Neg(x *Rat) *Rat {
z.a.Neg(&x.a) z.a.Neg(&x.a)
...@@ -200,7 +186,6 @@ func (z *Rat) Neg(x *Rat) *Rat { ...@@ -200,7 +186,6 @@ func (z *Rat) Neg(x *Rat) *Rat {
return z return z
} }
// Set sets z to x (by making a copy of x if necessary) and returns z. // Set sets z to x (by making a copy of x if necessary) and returns z.
func (z *Rat) Set(x *Rat) *Rat { func (z *Rat) Set(x *Rat) *Rat {
z.a.Set(&x.a) z.a.Set(&x.a)
...@@ -208,6 +193,25 @@ func (z *Rat) Set(x *Rat) *Rat { ...@@ -208,6 +193,25 @@ func (z *Rat) Set(x *Rat) *Rat {
return z return z
} }
func ratTok(ch int) bool {
return strings.IndexRune("+-/0123456789.eE", ch) >= 0
}
// Scan is a support routine for fmt.Scanner. It accepts the formats
// 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent.
func (z *Rat) Scan(s fmt.ScanState, ch int) os.Error {
tok, err := s.Token(true, ratTok)
if err != nil {
return err
}
if strings.IndexRune("efgEFGv", ch) < 0 {
return os.NewError("Rat.Scan: invalid verb")
}
if _, ok := z.SetString(string(tok)); !ok {
return os.NewError("Rat.Scan: invalid syntax")
}
return nil
}
// SetString sets z to the value of s and returns z and a boolean indicating // SetString sets z to the value of s and returns z and a boolean indicating
// success. s can be given as a fraction "a/b" or as a floating-point number // success. s can be given as a fraction "a/b" or as a floating-point number
...@@ -225,8 +229,8 @@ func (z *Rat) SetString(s string) (*Rat, bool) { ...@@ -225,8 +229,8 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
return z, false return z, false
} }
s = s[sep+1:] s = s[sep+1:]
var n int var err os.Error
if z.b, _, n = z.b.scan(s, 10); n != len(s) { if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil {
return z, false return z, false
} }
return z.norm(), true return z.norm(), true
...@@ -267,13 +271,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) { ...@@ -267,13 +271,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
return z, true return z, true
} }
// String returns a string representation of z in the form "a/b" (even if b == 1). // String returns a string representation of z in the form "a/b" (even if b == 1).
func (z *Rat) String() string { func (z *Rat) String() string {
return z.a.String() + "/" + z.b.string(10) return z.a.String() + "/" + z.b.decimalString()
} }
// RatString returns a string representation of z in the form "a/b" if b != 1, // RatString returns a string representation of z in the form "a/b" if b != 1,
// and in the form "a" if b == 1. // and in the form "a" if b == 1.
func (z *Rat) RatString() string { func (z *Rat) RatString() string {
...@@ -283,12 +285,15 @@ func (z *Rat) RatString() string { ...@@ -283,12 +285,15 @@ func (z *Rat) RatString() string {
return z.String() return z.String()
} }
// FloatString returns a string representation of z in decimal form with prec // FloatString returns a string representation of z in decimal form with prec
// digits of precision after the decimal point and the last digit rounded. // digits of precision after the decimal point and the last digit rounded.
func (z *Rat) FloatString(prec int) string { func (z *Rat) FloatString(prec int) string {
if z.IsInt() { if z.IsInt() {
return z.a.String() s := z.a.String()
if prec > 0 {
s += "." + strings.Repeat("0", prec)
}
return s
} }
q, r := nat{}.div(nat{}, z.a.abs, z.b) q, r := nat{}.div(nat{}, z.a.abs, z.b)
...@@ -311,16 +316,56 @@ func (z *Rat) FloatString(prec int) string { ...@@ -311,16 +316,56 @@ func (z *Rat) FloatString(prec int) string {
} }
} }
s := q.string(10) s := q.decimalString()
if z.a.neg { if z.a.neg {
s = "-" + s s = "-" + s
} }
if prec > 0 { if prec > 0 {
rs := r.string(10) rs := r.decimalString()
leadingZeros := prec - len(rs) leadingZeros := prec - len(rs)
s += "." + strings.Repeat("0", leadingZeros) + rs s += "." + strings.Repeat("0", leadingZeros) + rs
} }
return s return s
} }
// Gob codec version. Permits backward-compatible changes to the encoding.
const ratGobVersion byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (z *Rat) GobEncode() ([]byte, os.Error) {
buf := make([]byte, 1+4+(len(z.a.abs)+len(z.b))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
i := z.b.bytes(buf)
j := z.a.abs.bytes(buf[0:i])
n := i - j
if int(uint32(n)) != n {
// this should never happen
return nil, os.NewError("Rat.GobEncode: numerator too large")
}
binary.BigEndian.PutUint32(buf[j-4:j], uint32(n))
j -= 1 + 4
b := ratGobVersion << 1 // make space for sign bit
if z.a.neg {
b |= 1
}
buf[j] = b
return buf[j:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Rat) GobDecode(buf []byte) os.Error {
if len(buf) == 0 {
return os.NewError("Rat.GobDecode: no data")
}
b := buf[0]
if b>>1 != ratGobVersion {
return os.NewError(fmt.Sprintf("Rat.GobDecode: encoding version %d not supported", b>>1))
}
const j = 1 + 4
i := j + binary.BigEndian.Uint32(buf[j-4:j])
z.a.neg = b&1 != 0
z.a.abs = z.a.abs.setBytes(buf[j:i])
z.b = z.b.setBytes(buf[i:])
return nil
}
...@@ -4,8 +4,12 @@ ...@@ -4,8 +4,12 @@
package big package big
import "testing" import (
"bytes"
"fmt"
"gob"
"testing"
)
var setStringTests = []struct { var setStringTests = []struct {
in, out string in, out string
...@@ -52,6 +56,27 @@ func TestRatSetString(t *testing.T) { ...@@ -52,6 +56,27 @@ func TestRatSetString(t *testing.T) {
} }
} }
func TestRatScan(t *testing.T) {
var buf bytes.Buffer
for i, test := range setStringTests {
x := new(Rat)
buf.Reset()
buf.WriteString(test.in)
_, err := fmt.Fscanf(&buf, "%v", x)
if err == nil != test.ok {
if test.ok {
t.Errorf("#%d error: %s", i, err.String())
} else {
t.Errorf("#%d expected error", i)
}
continue
}
if err == nil && x.RatString() != test.out {
t.Errorf("#%d got %s want %s", i, x.RatString(), test.out)
}
}
}
var floatStringTests = []struct { var floatStringTests = []struct {
in string in string
...@@ -59,12 +84,13 @@ var floatStringTests = []struct { ...@@ -59,12 +84,13 @@ var floatStringTests = []struct {
out string out string
}{ }{
{"0", 0, "0"}, {"0", 0, "0"},
{"0", 4, "0"}, {"0", 4, "0.0000"},
{"1", 0, "1"}, {"1", 0, "1"},
{"1", 2, "1"}, {"1", 2, "1.00"},
{"-1", 0, "-1"}, {"-1", 0, "-1"},
{".25", 2, "0.25"}, {".25", 2, "0.25"},
{".25", 1, "0.3"}, {".25", 1, "0.3"},
{".25", 3, "0.250"},
{"-1/3", 3, "-0.333"}, {"-1/3", 3, "-0.333"},
{"-2/3", 4, "-0.6667"}, {"-2/3", 4, "-0.6667"},
{"0.96", 1, "1.0"}, {"0.96", 1, "1.0"},
...@@ -84,7 +110,6 @@ func TestFloatString(t *testing.T) { ...@@ -84,7 +110,6 @@ func TestFloatString(t *testing.T) {
} }
} }
func TestRatSign(t *testing.T) { func TestRatSign(t *testing.T) {
zero := NewRat(0, 1) zero := NewRat(0, 1)
for _, a := range setStringTests { for _, a := range setStringTests {
...@@ -98,7 +123,6 @@ func TestRatSign(t *testing.T) { ...@@ -98,7 +123,6 @@ func TestRatSign(t *testing.T) {
} }
} }
var ratCmpTests = []struct { var ratCmpTests = []struct {
rat1, rat2 string rat1, rat2 string
out int out int
...@@ -126,7 +150,6 @@ func TestRatCmp(t *testing.T) { ...@@ -126,7 +150,6 @@ func TestRatCmp(t *testing.T) {
} }
} }
func TestIsInt(t *testing.T) { func TestIsInt(t *testing.T) {
one := NewInt(1) one := NewInt(1)
for _, a := range setStringTests { for _, a := range setStringTests {
...@@ -140,7 +163,6 @@ func TestIsInt(t *testing.T) { ...@@ -140,7 +163,6 @@ func TestIsInt(t *testing.T) {
} }
} }
func TestRatAbs(t *testing.T) { func TestRatAbs(t *testing.T) {
zero := NewRat(0, 1) zero := NewRat(0, 1)
for _, a := range setStringTests { for _, a := range setStringTests {
...@@ -158,7 +180,6 @@ func TestRatAbs(t *testing.T) { ...@@ -158,7 +180,6 @@ func TestRatAbs(t *testing.T) {
} }
} }
type ratBinFun func(z, x, y *Rat) *Rat type ratBinFun func(z, x, y *Rat) *Rat
type ratBinArg struct { type ratBinArg struct {
x, y, z string x, y, z string
...@@ -175,7 +196,6 @@ func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) { ...@@ -175,7 +196,6 @@ func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) {
} }
} }
var ratBinTests = []struct { var ratBinTests = []struct {
x, y string x, y string
sum, prod string sum, prod string
...@@ -232,7 +252,6 @@ func TestRatBin(t *testing.T) { ...@@ -232,7 +252,6 @@ func TestRatBin(t *testing.T) {
} }
} }
func TestIssue820(t *testing.T) { func TestIssue820(t *testing.T) {
x := NewRat(3, 1) x := NewRat(3, 1)
y := NewRat(2, 1) y := NewRat(2, 1)
...@@ -258,7 +277,6 @@ func TestIssue820(t *testing.T) { ...@@ -258,7 +277,6 @@ func TestIssue820(t *testing.T) {
} }
} }
var setFrac64Tests = []struct { var setFrac64Tests = []struct {
a, b int64 a, b int64
out string out string
...@@ -280,3 +298,35 @@ func TestRatSetFrac64Rat(t *testing.T) { ...@@ -280,3 +298,35 @@ func TestRatSetFrac64Rat(t *testing.T) {
} }
} }
} }
func TestRatGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
for i, test := range gobEncodingTests {
for j := 0; j < 4; j++ {
medium.Reset() // empty buffer for each test case (in case of failures)
stest := test
if j&1 != 0 {
// negative numbers
stest = "-" + test
}
if j%2 != 0 {
// fractions
stest = stest + "." + test
}
var tx Rat
tx.SetString(stest)
if err := enc.Encode(&tx); err != nil {
t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
}
var rx Rat
if err := dec.Decode(&rx); err != nil {
t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
}
if rx.Cmp(&tx) != 0 {
t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
}
}
}
}
...@@ -15,16 +15,17 @@ import ( ...@@ -15,16 +15,17 @@ import (
"utf8" "utf8"
) )
const ( const (
defaultBufSize = 4096 defaultBufSize = 4096
) )
// Errors introduced by this package. // Errors introduced by this package.
type Error struct { type Error struct {
os.ErrorString ErrorString string
} }
func (err *Error) String() string { return err.ErrorString }
var ( var (
ErrInvalidUnreadByte os.Error = &Error{"bufio: invalid use of UnreadByte"} ErrInvalidUnreadByte os.Error = &Error{"bufio: invalid use of UnreadByte"}
ErrInvalidUnreadRune os.Error = &Error{"bufio: invalid use of UnreadRune"} ErrInvalidUnreadRune os.Error = &Error{"bufio: invalid use of UnreadRune"}
...@@ -40,7 +41,6 @@ func (b BufSizeError) String() string { ...@@ -40,7 +41,6 @@ func (b BufSizeError) String() string {
return "bufio: bad buffer size " + strconv.Itoa(int(b)) return "bufio: bad buffer size " + strconv.Itoa(int(b))
} }
// Buffered input. // Buffered input.
// Reader implements buffering for an io.Reader object. // Reader implements buffering for an io.Reader object.
...@@ -101,6 +101,12 @@ func (b *Reader) fill() { ...@@ -101,6 +101,12 @@ func (b *Reader) fill() {
} }
} }
func (b *Reader) readErr() os.Error {
err := b.err
b.err = nil
return err
}
// Peek returns the next n bytes without advancing the reader. The bytes stop // Peek returns the next n bytes without advancing the reader. The bytes stop
// being valid at the next read call. If Peek returns fewer than n bytes, it // being valid at the next read call. If Peek returns fewer than n bytes, it
// also returns an error explaining why the read is short. The error is // also returns an error explaining why the read is short. The error is
...@@ -119,7 +125,7 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) { ...@@ -119,7 +125,7 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
if m > n { if m > n {
m = n m = n
} }
err := b.err err := b.readErr()
if m < n && err == nil { if m < n && err == nil {
err = ErrBufferFull err = ErrBufferFull
} }
...@@ -134,11 +140,11 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) { ...@@ -134,11 +140,11 @@ func (b *Reader) Peek(n int) ([]byte, os.Error) {
func (b *Reader) Read(p []byte) (n int, err os.Error) { func (b *Reader) Read(p []byte) (n int, err os.Error) {
n = len(p) n = len(p)
if n == 0 { if n == 0 {
return 0, b.err return 0, b.readErr()
} }
if b.w == b.r { if b.w == b.r {
if b.err != nil { if b.err != nil {
return 0, b.err return 0, b.readErr()
} }
if len(p) >= len(b.buf) { if len(p) >= len(b.buf) {
// Large read, empty buffer. // Large read, empty buffer.
...@@ -148,11 +154,11 @@ func (b *Reader) Read(p []byte) (n int, err os.Error) { ...@@ -148,11 +154,11 @@ func (b *Reader) Read(p []byte) (n int, err os.Error) {
b.lastByte = int(p[n-1]) b.lastByte = int(p[n-1])
b.lastRuneSize = -1 b.lastRuneSize = -1
} }
return n, b.err return n, b.readErr()
} }
b.fill() b.fill()
if b.w == b.r { if b.w == b.r {
return 0, b.err return 0, b.readErr()
} }
} }
...@@ -172,7 +178,7 @@ func (b *Reader) ReadByte() (c byte, err os.Error) { ...@@ -172,7 +178,7 @@ func (b *Reader) ReadByte() (c byte, err os.Error) {
b.lastRuneSize = -1 b.lastRuneSize = -1
for b.w == b.r { for b.w == b.r {
if b.err != nil { if b.err != nil {
return 0, b.err return 0, b.readErr()
} }
b.fill() b.fill()
} }
...@@ -208,7 +214,7 @@ func (b *Reader) ReadRune() (rune int, size int, err os.Error) { ...@@ -208,7 +214,7 @@ func (b *Reader) ReadRune() (rune int, size int, err os.Error) {
} }
b.lastRuneSize = -1 b.lastRuneSize = -1
if b.r == b.w { if b.r == b.w {
return 0, 0, b.err return 0, 0, b.readErr()
} }
rune, size = int(b.buf[b.r]), 1 rune, size = int(b.buf[b.r]), 1
if rune >= 0x80 { if rune >= 0x80 {
...@@ -260,7 +266,7 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err os.Error) { ...@@ -260,7 +266,7 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err os.Error) {
if b.err != nil { if b.err != nil {
line := b.buf[b.r:b.w] line := b.buf[b.r:b.w]
b.r = b.w b.r = b.w
return line, b.err return line, b.readErr()
} }
n := b.Buffered() n := b.Buffered()
...@@ -367,7 +373,6 @@ func (b *Reader) ReadString(delim byte) (line string, err os.Error) { ...@@ -367,7 +373,6 @@ func (b *Reader) ReadString(delim byte) (line string, err os.Error) {
return string(bytes), e return string(bytes), e
} }
// buffered output // buffered output
// Writer implements buffering for an io.Writer object. // Writer implements buffering for an io.Writer object.
......
...@@ -53,11 +53,12 @@ func readBytes(buf *Reader) string { ...@@ -53,11 +53,12 @@ func readBytes(buf *Reader) string {
if e == os.EOF { if e == os.EOF {
break break
} }
if e != nil { if e == nil {
b[nb] = c
nb++
} else if e != iotest.ErrTimeout {
panic("Data: " + e.String()) panic("Data: " + e.String())
} }
b[nb] = c
nb++
} }
return string(b[0:nb]) return string(b[0:nb])
} }
...@@ -75,7 +76,6 @@ func TestReaderSimple(t *testing.T) { ...@@ -75,7 +76,6 @@ func TestReaderSimple(t *testing.T) {
} }
} }
type readMaker struct { type readMaker struct {
name string name string
fn func(io.Reader) io.Reader fn func(io.Reader) io.Reader
...@@ -86,6 +86,7 @@ var readMakers = []readMaker{ ...@@ -86,6 +86,7 @@ var readMakers = []readMaker{
{"byte", iotest.OneByteReader}, {"byte", iotest.OneByteReader},
{"half", iotest.HalfReader}, {"half", iotest.HalfReader},
{"data+err", iotest.DataErrReader}, {"data+err", iotest.DataErrReader},
{"timeout", iotest.TimeoutReader},
} }
// Call ReadString (which ends up calling everything else) // Call ReadString (which ends up calling everything else)
...@@ -97,7 +98,7 @@ func readLines(b *Reader) string { ...@@ -97,7 +98,7 @@ func readLines(b *Reader) string {
if e == os.EOF { if e == os.EOF {
break break
} }
if e != nil { if e != nil && e != iotest.ErrTimeout {
panic("GetLines: " + e.String()) panic("GetLines: " + e.String())
} }
s += s1 s += s1
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package builtin provides documentation for Go's built-in functions.
The functions documented here are not actually in package builtin
but their descriptions here allow godoc to present documentation
for the language's special functions.
*/
package builtin
// Type is here for the purposes of documentation only. It is a stand-in
// for any Go type, but represents the same type for any given function
// invocation.
type Type int
// IntegerType is here for the purposes of documentation only. It is a stand-in
// for any integer type: int, uint, int8 etc.
type IntegerType int
// FloatType is here for the purposes of documentation only. It is a stand-in
// for either float type: float32 or float64.
type FloatType int
// ComplexType is here for the purposes of documentation only. It is a
// stand-in for either complex type: complex64 or complex128.
type ComplexType int
// The append built-in function appends elements to the end of a slice. If
// it has sufficient capacity, the destination is resliced to accommodate the
// new elements. If it does not, a new underlying array will be allocated.
// Append returns the updated slice. It is therefore necessary to store the
// result of append, often in the variable holding the slice itself:
// slice = append(slice, elem1, elem2)
// slice = append(slice, anotherSlice...)
func append(slice []Type, elems ...Type) []Type
// The copy built-in function copies elements from a source slice into a
// destination slice. (As a special case, it also will copy bytes from a
// string to a slice of bytes.) The source and destination may overlap. Copy
// returns the number of elements copied, which will be the minimum of
// len(src) and len(dst).
func copy(dst, src []Type) int
// The len built-in function returns the length of v, according to its type:
// Array: the number of elements in v.
// Pointer to array: the number of elements in *v (even if v is nil).
// Slice, or map: the number of elements in v; if v is nil, len(v) is zero.
// String: the number of bytes in v.
// Channel: the number of elements queued (unread) in the channel buffer;
// if v is nil, len(v) is zero.
func len(v Type) int
// The cap built-in function returns the capacity of v, according to its type:
// Array: the number of elements in v (same as len(v)).
// Pointer to array: the number of elements in *v (same as len(v)).
// Slice: the maximum length the slice can reach when resliced;
// if v is nil, cap(v) is zero.
// Channel: the channel buffer capacity, in units of elements;
// if v is nil, cap(v) is zero.
func cap(v Type) int
// The make built-in function allocates and initializes an object of type
// slice, map, or chan (only). Like new, the first argument is a type, not a
// value. Unlike new, make's return type is the same as the type of its
// argument, not a pointer to it. The specification of the result depends on
// the type:
// Slice: The size specifies the length. The capacity of the slice is
// equal to its length. A second integer argument may be provided to
// specify a different capacity; it must be no smaller than the
// length, so make([]int, 0, 10) allocates a slice of length 0 and
// capacity 10.
// Map: An initial allocation is made according to the size but the
// resulting map has length 0. The size may be omitted, in which case
// a small starting size is allocated.
// Channel: The channel's buffer is initialized with the specified
// buffer capacity. If zero, or the size is omitted, the channel is
// unbuffered.
func make(Type, size IntegerType) Type
// The new built-in function allocates memory. The first argument is a type,
// not a value, and the value returned is a pointer to a newly
// allocated zero value of that type.
func new(Type) *Type
// The complex built-in function constructs a complex value from two
// floating-point values. The real and imaginary parts must be of the same
// size, either float32 or float64 (or assignable to them), and the return
// value will be the corresponding complex type (complex64 for float32,
// complex128 for float64).
func complex(r, i FloatType) ComplexType
// The real built-in function returns the real part of the complex number c.
// The return value will be floating point type corresponding to the type of c.
func real(c ComplexType) FloatType
// The imaginary built-in function returns the imaginary part of the complex
// number c. The return value will be floating point type corresponding to
// the type of c.
func imag(c ComplexType) FloatType
// The close built-in function closes a channel, which must be either
// bidirectional or send-only. It should be executed only by the sender,
// never the receiver, and has the effect of shutting down the channel after
// the last sent value is received. After the last value has been received
// from a closed channel c, any receive from c will succeed without
// blocking, returning the zero value for the channel element. The form
// x, ok := <-c
// will also set ok to false for a closed channel.
func close(c chan<- Type)
// The panic built-in function stops normal execution of the current
// goroutine. When a function F calls panic, normal execution of F stops
// immediately. Any functions whose execution was deferred by F are run in
// the usual way, and then F returns to its caller. To the caller G, the
// invocation of F then behaves like a call to panic, terminating G's
// execution and running any deferred functions. This continues until all
// functions in the executing goroutine have stopped, in reverse order. At
// that point, the program is terminated and the error condition is reported,
// including the value of the argument to panic. This termination sequence
// is called panicking and can be controlled by the built-in function
// recover.
func panic(v interface{})
// The recover built-in function allows a program to manage behavior of a
// panicking goroutine. Executing a call to recover inside a deferred
// function (but not any function called by it) stops the panicking sequence
// by restoring normal execution and retrieves the error value passed to the
// call of panic. If recover is called outside the deferred function it will
// not stop a panicking sequence. In this case, or when the goroutine is not
// panicking, or if the argument supplied to panic was nil, recover returns
// nil. Thus the return value from recover reports whether the goroutine is
// panicking.
func recover() interface{}
...@@ -280,7 +280,7 @@ func (b *Buffer) ReadRune() (r int, size int, err os.Error) { ...@@ -280,7 +280,7 @@ func (b *Buffer) ReadRune() (r int, size int, err os.Error) {
// from any read operation.) // from any read operation.)
func (b *Buffer) UnreadRune() os.Error { func (b *Buffer) UnreadRune() os.Error {
if b.lastRead != opReadRune { if b.lastRead != opReadRune {
return os.ErrorString("bytes.Buffer: UnreadRune: previous operation was not ReadRune") return os.NewError("bytes.Buffer: UnreadRune: previous operation was not ReadRune")
} }
b.lastRead = opInvalid b.lastRead = opInvalid
if b.off > 0 { if b.off > 0 {
...@@ -295,7 +295,7 @@ func (b *Buffer) UnreadRune() os.Error { ...@@ -295,7 +295,7 @@ func (b *Buffer) UnreadRune() os.Error {
// returns an error. // returns an error.
func (b *Buffer) UnreadByte() os.Error { func (b *Buffer) UnreadByte() os.Error {
if b.lastRead != opReadRune && b.lastRead != opRead { if b.lastRead != opReadRune && b.lastRead != opRead {
return os.ErrorString("bytes.Buffer: UnreadByte: previous operation was not a read") return os.NewError("bytes.Buffer: UnreadByte: previous operation was not a read")
} }
b.lastRead = opInvalid b.lastRead = opInvalid
if b.off > 0 { if b.off > 0 {
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"utf8" "utf8"
) )
const N = 10000 // make this bigger for a larger (and slower) test const N = 10000 // make this bigger for a larger (and slower) test
var data string // test data for write tests var data string // test data for write tests
var bytes []byte // test data; same as data but as a slice. var bytes []byte // test data; same as data but as a slice.
...@@ -47,7 +46,6 @@ func check(t *testing.T, testname string, buf *Buffer, s string) { ...@@ -47,7 +46,6 @@ func check(t *testing.T, testname string, buf *Buffer, s string) {
} }
} }
// Fill buf through n writes of string fus. // Fill buf through n writes of string fus.
// The initial contents of buf corresponds to the string s; // The initial contents of buf corresponds to the string s;
// the result is the final contents of buf returned as a string. // the result is the final contents of buf returned as a string.
...@@ -67,7 +65,6 @@ func fillString(t *testing.T, testname string, buf *Buffer, s string, n int, fus ...@@ -67,7 +65,6 @@ func fillString(t *testing.T, testname string, buf *Buffer, s string, n int, fus
return s return s
} }
// Fill buf through n writes of byte slice fub. // Fill buf through n writes of byte slice fub.
// The initial contents of buf corresponds to the string s; // The initial contents of buf corresponds to the string s;
// the result is the final contents of buf returned as a string. // the result is the final contents of buf returned as a string.
...@@ -87,19 +84,16 @@ func fillBytes(t *testing.T, testname string, buf *Buffer, s string, n int, fub ...@@ -87,19 +84,16 @@ func fillBytes(t *testing.T, testname string, buf *Buffer, s string, n int, fub
return s return s
} }
func TestNewBuffer(t *testing.T) { func TestNewBuffer(t *testing.T) {
buf := NewBuffer(bytes) buf := NewBuffer(bytes)
check(t, "NewBuffer", buf, data) check(t, "NewBuffer", buf, data)
} }
func TestNewBufferString(t *testing.T) { func TestNewBufferString(t *testing.T) {
buf := NewBufferString(data) buf := NewBufferString(data)
check(t, "NewBufferString", buf, data) check(t, "NewBufferString", buf, data)
} }
// Empty buf through repeated reads into fub. // Empty buf through repeated reads into fub.
// The initial contents of buf corresponds to the string s. // The initial contents of buf corresponds to the string s.
func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) { func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
...@@ -120,7 +114,6 @@ func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) { ...@@ -120,7 +114,6 @@ func empty(t *testing.T, testname string, buf *Buffer, s string, fub []byte) {
check(t, testname+" (empty 4)", buf, "") check(t, testname+" (empty 4)", buf, "")
} }
func TestBasicOperations(t *testing.T) { func TestBasicOperations(t *testing.T) {
var buf Buffer var buf Buffer
...@@ -175,7 +168,6 @@ func TestBasicOperations(t *testing.T) { ...@@ -175,7 +168,6 @@ func TestBasicOperations(t *testing.T) {
} }
} }
func TestLargeStringWrites(t *testing.T) { func TestLargeStringWrites(t *testing.T) {
var buf Buffer var buf Buffer
limit := 30 limit := 30
...@@ -189,7 +181,6 @@ func TestLargeStringWrites(t *testing.T) { ...@@ -189,7 +181,6 @@ func TestLargeStringWrites(t *testing.T) {
check(t, "TestLargeStringWrites (3)", &buf, "") check(t, "TestLargeStringWrites (3)", &buf, "")
} }
func TestLargeByteWrites(t *testing.T) { func TestLargeByteWrites(t *testing.T) {
var buf Buffer var buf Buffer
limit := 30 limit := 30
...@@ -203,7 +194,6 @@ func TestLargeByteWrites(t *testing.T) { ...@@ -203,7 +194,6 @@ func TestLargeByteWrites(t *testing.T) {
check(t, "TestLargeByteWrites (3)", &buf, "") check(t, "TestLargeByteWrites (3)", &buf, "")
} }
func TestLargeStringReads(t *testing.T) { func TestLargeStringReads(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
...@@ -213,7 +203,6 @@ func TestLargeStringReads(t *testing.T) { ...@@ -213,7 +203,6 @@ func TestLargeStringReads(t *testing.T) {
check(t, "TestLargeStringReads (3)", &buf, "") check(t, "TestLargeStringReads (3)", &buf, "")
} }
func TestLargeByteReads(t *testing.T) { func TestLargeByteReads(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
...@@ -223,7 +212,6 @@ func TestLargeByteReads(t *testing.T) { ...@@ -223,7 +212,6 @@ func TestLargeByteReads(t *testing.T) {
check(t, "TestLargeByteReads (3)", &buf, "") check(t, "TestLargeByteReads (3)", &buf, "")
} }
func TestMixedReadsAndWrites(t *testing.T) { func TestMixedReadsAndWrites(t *testing.T) {
var buf Buffer var buf Buffer
s := "" s := ""
...@@ -243,7 +231,6 @@ func TestMixedReadsAndWrites(t *testing.T) { ...@@ -243,7 +231,6 @@ func TestMixedReadsAndWrites(t *testing.T) {
empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len())) empty(t, "TestMixedReadsAndWrites (2)", &buf, s, make([]byte, buf.Len()))
} }
func TestNil(t *testing.T) { func TestNil(t *testing.T) {
var b *Buffer var b *Buffer
if b.String() != "<nil>" { if b.String() != "<nil>" {
...@@ -251,7 +238,6 @@ func TestNil(t *testing.T) { ...@@ -251,7 +238,6 @@ func TestNil(t *testing.T) {
} }
} }
func TestReadFrom(t *testing.T) { func TestReadFrom(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
...@@ -262,7 +248,6 @@ func TestReadFrom(t *testing.T) { ...@@ -262,7 +248,6 @@ func TestReadFrom(t *testing.T) {
} }
} }
func TestWriteTo(t *testing.T) { func TestWriteTo(t *testing.T) {
var buf Buffer var buf Buffer
for i := 3; i < 30; i += 3 { for i := 3; i < 30; i += 3 {
...@@ -273,7 +258,6 @@ func TestWriteTo(t *testing.T) { ...@@ -273,7 +258,6 @@ func TestWriteTo(t *testing.T) {
} }
} }
func TestRuneIO(t *testing.T) { func TestRuneIO(t *testing.T) {
const NRune = 1000 const NRune = 1000
// Built a test array while we write the data // Built a test array while we write the data
...@@ -323,7 +307,6 @@ func TestRuneIO(t *testing.T) { ...@@ -323,7 +307,6 @@ func TestRuneIO(t *testing.T) {
} }
} }
func TestNext(t *testing.T) { func TestNext(t *testing.T) {
b := []byte{0, 1, 2, 3, 4} b := []byte{0, 1, 2, 3, 4}
tmp := make([]byte, 5) tmp := make([]byte, 5)
......
...@@ -212,26 +212,40 @@ func genSplit(s, sep []byte, sepSave, n int) [][]byte { ...@@ -212,26 +212,40 @@ func genSplit(s, sep []byte, sepSave, n int) [][]byte {
return a[0 : na+1] return a[0 : na+1]
} }
// Split slices s into subslices separated by sep and returns a slice of // SplitN slices s into subslices separated by sep and returns a slice of
// the subslices between those separators. // the subslices between those separators.
// If sep is empty, Split splits after each UTF-8 sequence. // If sep is empty, SplitN splits after each UTF-8 sequence.
// The count determines the number of subslices to return: // The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder. // n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices) // n == 0: the result is nil (zero subslices)
// n < 0: all subslices // n < 0: all subslices
func Split(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) } func SplitN(s, sep []byte, n int) [][]byte { return genSplit(s, sep, 0, n) }
// SplitAfter slices s into subslices after each instance of sep and // SplitAfterN slices s into subslices after each instance of sep and
// returns a slice of those subslices. // returns a slice of those subslices.
// If sep is empty, Split splits after each UTF-8 sequence. // If sep is empty, SplitAfterN splits after each UTF-8 sequence.
// The count determines the number of subslices to return: // The count determines the number of subslices to return:
// n > 0: at most n subslices; the last subslice will be the unsplit remainder. // n > 0: at most n subslices; the last subslice will be the unsplit remainder.
// n == 0: the result is nil (zero subslices) // n == 0: the result is nil (zero subslices)
// n < 0: all subslices // n < 0: all subslices
func SplitAfter(s, sep []byte, n int) [][]byte { func SplitAfterN(s, sep []byte, n int) [][]byte {
return genSplit(s, sep, len(sep), n) return genSplit(s, sep, len(sep), n)
} }
// Split slices s into all subslices separated by sep and returns a slice of
// the subslices between those separators.
// If sep is empty, Split splits after each UTF-8 sequence.
// It is equivalent to SplitN with a count of -1.
func Split(s, sep []byte) [][]byte { return genSplit(s, sep, 0, -1) }
// SplitAfter slices s into all subslices after each instance of sep and
// returns a slice of those subslices.
// If sep is empty, SplitAfter splits after each UTF-8 sequence.
// It is equivalent to SplitAfterN with a count of -1.
func SplitAfter(s, sep []byte) [][]byte {
return genSplit(s, sep, len(sep), -1)
}
// Fields splits the array s around each instance of one or more consecutive white space // Fields splits the array s around each instance of one or more consecutive white space
// characters, returning a slice of subarrays of s or an empty list if s contains only white space. // characters, returning a slice of subarrays of s or an empty list if s contains only white space.
func Fields(s []byte) [][]byte { func Fields(s []byte) [][]byte {
...@@ -384,7 +398,6 @@ func ToTitleSpecial(_case unicode.SpecialCase, s []byte) []byte { ...@@ -384,7 +398,6 @@ func ToTitleSpecial(_case unicode.SpecialCase, s []byte) []byte {
return Map(func(r int) int { return _case.ToTitle(r) }, s) return Map(func(r int) int { return _case.ToTitle(r) }, s)
} }
// isSeparator reports whether the rune could mark a word boundary. // isSeparator reports whether the rune could mark a word boundary.
// TODO: update when package unicode captures more of the properties. // TODO: update when package unicode captures more of the properties.
func isSeparator(rune int) bool { func isSeparator(rune int) bool {
......
...@@ -6,6 +6,7 @@ package bytes_test ...@@ -6,6 +6,7 @@ package bytes_test
import ( import (
. "bytes" . "bytes"
"reflect"
"testing" "testing"
"unicode" "unicode"
"utf8" "utf8"
...@@ -315,7 +316,7 @@ var explodetests = []ExplodeTest{ ...@@ -315,7 +316,7 @@ var explodetests = []ExplodeTest{
func TestExplode(t *testing.T) { func TestExplode(t *testing.T) {
for _, tt := range explodetests { for _, tt := range explodetests {
a := Split([]byte(tt.s), nil, tt.n) a := SplitN([]byte(tt.s), nil, tt.n)
result := arrayOfString(a) result := arrayOfString(a)
if !eq(result, tt.a) { if !eq(result, tt.a) {
t.Errorf(`Explode("%s", %d) = %v; want %v`, tt.s, tt.n, result, tt.a) t.Errorf(`Explode("%s", %d) = %v; want %v`, tt.s, tt.n, result, tt.a)
...@@ -328,7 +329,6 @@ func TestExplode(t *testing.T) { ...@@ -328,7 +329,6 @@ func TestExplode(t *testing.T) {
} }
} }
type SplitTest struct { type SplitTest struct {
s string s string
sep string sep string
...@@ -354,7 +354,7 @@ var splittests = []SplitTest{ ...@@ -354,7 +354,7 @@ var splittests = []SplitTest{
func TestSplit(t *testing.T) { func TestSplit(t *testing.T) {
for _, tt := range splittests { for _, tt := range splittests {
a := Split([]byte(tt.s), []byte(tt.sep), tt.n) a := SplitN([]byte(tt.s), []byte(tt.sep), tt.n)
result := arrayOfString(a) result := arrayOfString(a)
if !eq(result, tt.a) { if !eq(result, tt.a) {
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a) t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
...@@ -367,6 +367,12 @@ func TestSplit(t *testing.T) { ...@@ -367,6 +367,12 @@ func TestSplit(t *testing.T) {
if string(s) != tt.s { if string(s) != tt.s {
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s) t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
} }
if tt.n < 0 {
b := Split([]byte(tt.s), []byte(tt.sep))
if !reflect.DeepEqual(a, b) {
t.Errorf("Split disagrees withSplitN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
}
}
} }
} }
...@@ -388,7 +394,7 @@ var splitaftertests = []SplitTest{ ...@@ -388,7 +394,7 @@ var splitaftertests = []SplitTest{
func TestSplitAfter(t *testing.T) { func TestSplitAfter(t *testing.T) {
for _, tt := range splitaftertests { for _, tt := range splitaftertests {
a := SplitAfter([]byte(tt.s), []byte(tt.sep), tt.n) a := SplitAfterN([]byte(tt.s), []byte(tt.sep), tt.n)
result := arrayOfString(a) result := arrayOfString(a)
if !eq(result, tt.a) { if !eq(result, tt.a) {
t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a) t.Errorf(`Split(%q, %q, %d) = %v; want %v`, tt.s, tt.sep, tt.n, result, tt.a)
...@@ -398,6 +404,12 @@ func TestSplitAfter(t *testing.T) { ...@@ -398,6 +404,12 @@ func TestSplitAfter(t *testing.T) {
if string(s) != tt.s { if string(s) != tt.s {
t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s) t.Errorf(`Join(Split(%q, %q, %d), %q) = %q`, tt.s, tt.sep, tt.n, tt.sep, s)
} }
if tt.n < 0 {
b := SplitAfter([]byte(tt.s), []byte(tt.sep))
if !reflect.DeepEqual(a, b) {
t.Errorf("SplitAfter disagrees withSplitAfterN(%q, %q, %d) = %v; want %v", tt.s, tt.sep, tt.n, b, a)
}
}
} }
} }
...@@ -649,7 +661,6 @@ func TestRunes(t *testing.T) { ...@@ -649,7 +661,6 @@ func TestRunes(t *testing.T) {
} }
} }
type TrimTest struct { type TrimTest struct {
f func([]byte, string) []byte f func([]byte, string) []byte
in, cutset, out string in, cutset, out string
......
...@@ -284,7 +284,7 @@ func (bz2 *reader) readBlock() (err os.Error) { ...@@ -284,7 +284,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
repeat := 0 repeat := 0
repeat_power := 0 repeat_power := 0
// The `C' array (used by the inverse BWT) needs to be zero initialised. // The `C' array (used by the inverse BWT) needs to be zero initialized.
for i := range bz2.c { for i := range bz2.c {
bz2.c[i] = 0 bz2.c[i] = 0
} }
...@@ -330,7 +330,7 @@ func (bz2 *reader) readBlock() (err os.Error) { ...@@ -330,7 +330,7 @@ func (bz2 *reader) readBlock() (err os.Error) {
if int(v) == numSymbols-1 { if int(v) == numSymbols-1 {
// This is the EOF symbol. Because it's always at the // This is the EOF symbol. Because it's always at the
// end of the move-to-front list, and nevers gets moved // end of the move-to-front list, and never gets moved
// to the front, it has this unique value. // to the front, it has this unique value.
break break
} }
......
...@@ -68,7 +68,7 @@ func newHuffmanTree(lengths []uint8) (huffmanTree, os.Error) { ...@@ -68,7 +68,7 @@ func newHuffmanTree(lengths []uint8) (huffmanTree, os.Error) {
// each symbol (consider reflecting a tree down the middle, for // each symbol (consider reflecting a tree down the middle, for
// example). Since the code length assignments determine the // example). Since the code length assignments determine the
// efficiency of the tree, each of these trees is equally good. In // efficiency of the tree, each of these trees is equally good. In
// order to minimise the amount of information needed to build a tree // order to minimize the amount of information needed to build a tree
// bzip2 uses a canonical tree so that it can be reconstructed given // bzip2 uses a canonical tree so that it can be reconstructed given
// only the code length assignments. // only the code length assignments.
......
...@@ -57,7 +57,7 @@ var deflateInflateTests = []*deflateInflateTest{ ...@@ -57,7 +57,7 @@ var deflateInflateTests = []*deflateInflateTest{
&deflateInflateTest{[]byte{0x11, 0x12}}, &deflateInflateTest{[]byte{0x11, 0x12}},
&deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}}, &deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
&deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}}, &deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
&deflateInflateTest{getLargeDataChunk()}, &deflateInflateTest{largeDataChunk()},
} }
var reverseBitsTests = []*reverseBitsTest{ var reverseBitsTests = []*reverseBitsTest{
...@@ -71,23 +71,22 @@ var reverseBitsTests = []*reverseBitsTest{ ...@@ -71,23 +71,22 @@ var reverseBitsTests = []*reverseBitsTest{
&reverseBitsTest{29, 5, 23}, &reverseBitsTest{29, 5, 23},
} }
func getLargeDataChunk() []byte { func largeDataChunk() []byte {
result := make([]byte, 100000) result := make([]byte, 100000)
for i := range result { for i := range result {
result[i] = byte(int64(i) * int64(i) & 0xFF) result[i] = byte(i * i & 0xFF)
} }
return result return result
} }
func TestDeflate(t *testing.T) { func TestDeflate(t *testing.T) {
for _, h := range deflateTests { for _, h := range deflateTests {
buffer := bytes.NewBuffer(nil) var buf bytes.Buffer
w := NewWriter(buffer, h.level) w := NewWriter(&buf, h.level)
w.Write(h.in) w.Write(h.in)
w.Close() w.Close()
if bytes.Compare(buffer.Bytes(), h.out) != 0 { if !bytes.Equal(buf.Bytes(), h.out) {
t.Errorf("buffer is wrong; level = %v, buffer.Bytes() = %v, expected output = %v", t.Errorf("Deflate(%d, %x) = %x, want %x", h.level, h.in, buf.Bytes(), h.out)
h.level, buffer.Bytes(), h.out)
} }
} }
} }
...@@ -226,7 +225,6 @@ func testSync(t *testing.T, level int, input []byte, name string) { ...@@ -226,7 +225,6 @@ func testSync(t *testing.T, level int, input []byte, name string) {
} }
} }
func testToFromWithLevel(t *testing.T, level int, input []byte, name string) os.Error { func testToFromWithLevel(t *testing.T, level int, input []byte, name string) os.Error {
buffer := bytes.NewBuffer(nil) buffer := bytes.NewBuffer(nil)
w := NewWriter(buffer, level) w := NewWriter(buffer, level)
......
...@@ -15,9 +15,6 @@ const ( ...@@ -15,9 +15,6 @@ const (
// The largest offset code. // The largest offset code.
offsetCodeCount = 30 offsetCodeCount = 30
// The largest offset code in the extensions.
extendedOffsetCodeCount = 42
// The special code used to mark the end of a block. // The special code used to mark the end of a block.
endBlockMarker = 256 endBlockMarker = 256
...@@ -100,11 +97,11 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { ...@@ -100,11 +97,11 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{ return &huffmanBitWriter{
w: w, w: w,
literalFreq: make([]int32, maxLit), literalFreq: make([]int32, maxLit),
offsetFreq: make([]int32, extendedOffsetCodeCount), offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxLit+extendedOffsetCodeCount+1), codegen: make([]uint8, maxLit+offsetCodeCount+1),
codegenFreq: make([]int32, codegenCodeCount), codegenFreq: make([]int32, codegenCodeCount),
literalEncoding: newHuffmanEncoder(maxLit), literalEncoding: newHuffmanEncoder(maxLit),
offsetEncoding: newHuffmanEncoder(extendedOffsetCodeCount), offsetEncoding: newHuffmanEncoder(offsetCodeCount),
codegenEncoding: newHuffmanEncoder(codegenCodeCount), codegenEncoding: newHuffmanEncoder(codegenCodeCount),
} }
} }
...@@ -185,7 +182,7 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) { ...@@ -185,7 +182,7 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
_, w.err = w.w.Write(bytes) _, w.err = w.w.Write(bytes)
} }
// RFC 1951 3.2.7 specifies a special run-length encoding for specifiying // RFC 1951 3.2.7 specifies a special run-length encoding for specifying
// the literal and offset lengths arrays (which are concatenated into a single // the literal and offset lengths arrays (which are concatenated into a single
// array). This method generates that run-length encoding. // array). This method generates that run-length encoding.
// //
...@@ -279,7 +276,7 @@ func (w *huffmanBitWriter) writeCode(code *huffmanEncoder, literal uint32) { ...@@ -279,7 +276,7 @@ func (w *huffmanBitWriter) writeCode(code *huffmanEncoder, literal uint32) {
// //
// numLiterals The number of literals specified in codegen // numLiterals The number of literals specified in codegen
// numOffsets The number of offsets specified in codegen // numOffsets The number of offsets specified in codegen
// numCodegens Tne number of codegens used in codegen // numCodegens The number of codegens used in codegen
func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) { func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) {
if w.err != nil { if w.err != nil {
return return
...@@ -290,13 +287,7 @@ func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, n ...@@ -290,13 +287,7 @@ func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, n
} }
w.writeBits(firstBits, 3) w.writeBits(firstBits, 3)
w.writeBits(int32(numLiterals-257), 5) w.writeBits(int32(numLiterals-257), 5)
if numOffsets > offsetCodeCount { w.writeBits(int32(numOffsets-1), 5)
// Extended version of decompressor
w.writeBits(int32(offsetCodeCount+((numOffsets-(1+offsetCodeCount))>>3)), 5)
w.writeBits(int32((numOffsets-(1+offsetCodeCount))&0x7), 3)
} else {
w.writeBits(int32(numOffsets-1), 5)
}
w.writeBits(int32(numCodegens-4), 4) w.writeBits(int32(numCodegens-4), 4)
for i := 0; i < numCodegens; i++ { for i := 0; i < numCodegens; i++ {
...@@ -368,24 +359,17 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) { ...@@ -368,24 +359,17 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
tokens = tokens[0 : n+1] tokens = tokens[0 : n+1]
tokens[n] = endBlockMarker tokens[n] = endBlockMarker
totalLength := -1 // Subtract 1 for endBlock.
for _, t := range tokens { for _, t := range tokens {
switch t.typ() { switch t.typ() {
case literalType: case literalType:
w.literalFreq[t.literal()]++ w.literalFreq[t.literal()]++
totalLength++
break
case matchType: case matchType:
length := t.length() length := t.length()
offset := t.offset() offset := t.offset()
totalLength += int(length + 3)
w.literalFreq[lengthCodesStart+lengthCode(length)]++ w.literalFreq[lengthCodesStart+lengthCode(length)]++
w.offsetFreq[offsetCode(offset)]++ w.offsetFreq[offsetCode(offset)]++
break
} }
} }
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
// get the number of literals // get the number of literals
numLiterals := len(w.literalFreq) numLiterals := len(w.literalFreq)
...@@ -394,15 +378,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) { ...@@ -394,15 +378,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
} }
// get the number of offsets // get the number of offsets
numOffsets := len(w.offsetFreq) numOffsets := len(w.offsetFreq)
for numOffsets > 1 && w.offsetFreq[numOffsets-1] == 0 { for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 {
numOffsets-- numOffsets--
} }
if numOffsets == 0 {
// We haven't found a single match. If we want to go with the dynamic encoding,
// we should count at least one offset to be sure that the offset huffman tree could be encoded.
w.offsetFreq[0] = 1
numOffsets = 1
}
w.literalEncoding.generate(w.literalFreq, 15)
w.offsetEncoding.generate(w.offsetFreq, 15)
storedBytes := 0 storedBytes := 0
if input != nil { if input != nil {
storedBytes = len(input) storedBytes = len(input)
} }
var extraBits int64 var extraBits int64
var storedSize int64 var storedSize int64 = math.MaxInt64
if storedBytes <= maxStoreBlockSize && input != nil { if storedBytes <= maxStoreBlockSize && input != nil {
storedSize = int64((storedBytes + 5) * 8) storedSize = int64((storedBytes + 5) * 8)
// We only bother calculating the costs of the extra bits required by // We only bother calculating the costs of the extra bits required by
...@@ -417,34 +411,29 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) { ...@@ -417,34 +411,29 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
// First four offset codes have extra size = 0. // First four offset codes have extra size = 0.
extraBits += int64(w.offsetFreq[offsetCode]) * int64(offsetExtraBits[offsetCode]) extraBits += int64(w.offsetFreq[offsetCode]) * int64(offsetExtraBits[offsetCode])
} }
} else {
storedSize = math.MaxInt32
} }
// Figure out which generates smaller code, fixed Huffman, dynamic // Figure out smallest code.
// Huffman, or just storing the data. // Fixed Huffman baseline.
var fixedSize int64 = math.MaxInt64 var size = int64(3) +
if numOffsets <= offsetCodeCount { fixedLiteralEncoding.bitLength(w.literalFreq) +
fixedSize = int64(3) + fixedOffsetEncoding.bitLength(w.offsetFreq) +
fixedLiteralEncoding.bitLength(w.literalFreq) + extraBits
fixedOffsetEncoding.bitLength(w.offsetFreq) + var literalEncoding = fixedLiteralEncoding
extraBits var offsetEncoding = fixedOffsetEncoding
}
// Dynamic Huffman?
var numCodegens int
// Generate codegen and codegenFrequencies, which indicates how to encode // Generate codegen and codegenFrequencies, which indicates how to encode
// the literalEncoding and the offsetEncoding. // the literalEncoding and the offsetEncoding.
w.generateCodegen(numLiterals, numOffsets) w.generateCodegen(numLiterals, numOffsets)
w.codegenEncoding.generate(w.codegenFreq, 7) w.codegenEncoding.generate(w.codegenFreq, 7)
numCodegens := len(w.codegenFreq) numCodegens = len(w.codegenFreq)
for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 { for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 {
numCodegens-- numCodegens--
} }
extensionSummand := 0
if numOffsets > offsetCodeCount {
extensionSummand = 3
}
dynamicHeader := int64(3+5+5+4+(3*numCodegens)) + dynamicHeader := int64(3+5+5+4+(3*numCodegens)) +
// Following line is an extension.
int64(extensionSummand) +
w.codegenEncoding.bitLength(w.codegenFreq) + w.codegenEncoding.bitLength(w.codegenFreq) +
int64(extraBits) + int64(extraBits) +
int64(w.codegenFreq[16]*2) + int64(w.codegenFreq[16]*2) +
...@@ -454,26 +443,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) { ...@@ -454,26 +443,25 @@ func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) {
w.literalEncoding.bitLength(w.literalFreq) + w.literalEncoding.bitLength(w.literalFreq) +
w.offsetEncoding.bitLength(w.offsetFreq) w.offsetEncoding.bitLength(w.offsetFreq)
if storedSize < fixedSize && storedSize < dynamicSize { if dynamicSize < size {
size = dynamicSize
literalEncoding = w.literalEncoding
offsetEncoding = w.offsetEncoding
}
// Stored bytes?
if storedSize < size {
w.writeStoredHeader(storedBytes, eof) w.writeStoredHeader(storedBytes, eof)
w.writeBytes(input[0:storedBytes]) w.writeBytes(input[0:storedBytes])
return return
} }
var literalEncoding *huffmanEncoder
var offsetEncoding *huffmanEncoder
if fixedSize <= dynamicSize { // Huffman.
if literalEncoding == fixedLiteralEncoding {
w.writeFixedHeader(eof) w.writeFixedHeader(eof)
literalEncoding = fixedLiteralEncoding
offsetEncoding = fixedOffsetEncoding
} else { } else {
// Write the header.
w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof)
literalEncoding = w.literalEncoding
offsetEncoding = w.offsetEncoding
} }
// Write the tokens.
for _, t := range tokens { for _, t := range tokens {
switch t.typ() { switch t.typ() {
case literalType: case literalType:
......
...@@ -363,7 +363,12 @@ func (s literalNodeSorter) Less(i, j int) bool { ...@@ -363,7 +363,12 @@ func (s literalNodeSorter) Less(i, j int) bool {
func (s literalNodeSorter) Swap(i, j int) { s.a[i], s.a[j] = s.a[j], s.a[i] } func (s literalNodeSorter) Swap(i, j int) { s.a[i], s.a[j] = s.a[j], s.a[i] }
func sortByFreq(a []literalNode) { func sortByFreq(a []literalNode) {
s := &literalNodeSorter{a, func(i, j int) bool { return a[i].freq < a[j].freq }} s := &literalNodeSorter{a, func(i, j int) bool {
if a[i].freq == a[j].freq {
return a[i].literal < a[j].literal
}
return a[i].freq < a[j].freq
}}
sort.Sort(s) sort.Sort(s)
} }
......
...@@ -77,8 +77,6 @@ type huffmanDecoder struct { ...@@ -77,8 +77,6 @@ type huffmanDecoder struct {
// Initialize Huffman decoding tables from array of code lengths. // Initialize Huffman decoding tables from array of code lengths.
func (h *huffmanDecoder) init(bits []int) bool { func (h *huffmanDecoder) init(bits []int) bool {
// TODO(rsc): Return false sometimes.
// Count number of codes of each length, // Count number of codes of each length,
// compute min and max length. // compute min and max length.
var count [maxCodeLen + 1]int var count [maxCodeLen + 1]int
...@@ -197,9 +195,8 @@ type Reader interface { ...@@ -197,9 +195,8 @@ type Reader interface {
// Decompress state. // Decompress state.
type decompressor struct { type decompressor struct {
// Input/output sources. // Input source.
r Reader r Reader
w io.Writer
roffset int64 roffset int64
woffset int64 woffset int64
...@@ -222,38 +219,79 @@ type decompressor struct { ...@@ -222,38 +219,79 @@ type decompressor struct {
// Temporary buffer (avoids repeated allocation). // Temporary buffer (avoids repeated allocation).
buf [4]byte buf [4]byte
// Next step in the decompression,
// and decompression state.
step func(*decompressor)
final bool
err os.Error
toRead []byte
hl, hd *huffmanDecoder
copyLen int
copyDist int
} }
func (f *decompressor) inflate() (err os.Error) { func (f *decompressor) nextBlock() {
final := false if f.final {
for err == nil && !final { if f.hw != f.hp {
for f.nb < 1+2 { f.flush((*decompressor).nextBlock)
if err = f.moreBits(); err != nil { return
return
}
} }
final = f.b&1 == 1 f.err = os.EOF
f.b >>= 1 return
typ := f.b & 3 }
f.b >>= 2 for f.nb < 1+2 {
f.nb -= 1 + 2 if f.err = f.moreBits(); f.err != nil {
switch typ { return
case 0: }
err = f.dataBlock() }
case 1: f.final = f.b&1 == 1
// compressed, fixed Huffman tables f.b >>= 1
err = f.decodeBlock(&fixedHuffmanDecoder, nil) typ := f.b & 3
case 2: f.b >>= 2
// compressed, dynamic Huffman tables f.nb -= 1 + 2
if err = f.readHuffman(); err == nil { switch typ {
err = f.decodeBlock(&f.h1, &f.h2) case 0:
} f.dataBlock()
default: case 1:
// 3 is reserved. // compressed, fixed Huffman tables
err = CorruptInputError(f.roffset) f.hl = &fixedHuffmanDecoder
f.hd = nil
f.huffmanBlock()
case 2:
// compressed, dynamic Huffman tables
if f.err = f.readHuffman(); f.err != nil {
break
}
f.hl = &f.h1
f.hd = &f.h2
f.huffmanBlock()
default:
// 3 is reserved.
f.err = CorruptInputError(f.roffset)
}
}
func (f *decompressor) Read(b []byte) (int, os.Error) {
for {
if len(f.toRead) > 0 {
n := copy(b, f.toRead)
f.toRead = f.toRead[n:]
return n, nil
}
if f.err != nil {
return 0, f.err
} }
f.step(f)
} }
return panic("unreachable")
}
func (f *decompressor) Close() os.Error {
if f.err == os.EOF {
return nil
}
return f.err
} }
// RFC 1951 section 3.2.7. // RFC 1951 section 3.2.7.
...@@ -358,11 +396,12 @@ func (f *decompressor) readHuffman() os.Error { ...@@ -358,11 +396,12 @@ func (f *decompressor) readHuffman() os.Error {
// hl and hd are the Huffman states for the lit/length values // hl and hd are the Huffman states for the lit/length values
// and the distance values, respectively. If hd == nil, using the // and the distance values, respectively. If hd == nil, using the
// fixed distance encoding associated with fixed Huffman blocks. // fixed distance encoding associated with fixed Huffman blocks.
func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { func (f *decompressor) huffmanBlock() {
for { for {
v, err := f.huffSym(hl) v, err := f.huffSym(f.hl)
if err != nil { if err != nil {
return err f.err = err
return
} }
var n uint // number of bits extra var n uint // number of bits extra
var length int var length int
...@@ -371,13 +410,15 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { ...@@ -371,13 +410,15 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hist[f.hp] = byte(v) f.hist[f.hp] = byte(v)
f.hp++ f.hp++
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { // After the flush, continue this loop.
return err f.flush((*decompressor).huffmanBlock)
} return
} }
continue continue
case v == 256: case v == 256:
return nil // Done with huffman block; read next block.
f.step = (*decompressor).nextBlock
return
// otherwise, reference to older data // otherwise, reference to older data
case v < 265: case v < 265:
length = v - (257 - 3) length = v - (257 - 3)
...@@ -404,7 +445,8 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { ...@@ -404,7 +445,8 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
if n > 0 { if n > 0 {
for f.nb < n { for f.nb < n {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
length += int(f.b & uint32(1<<n-1)) length += int(f.b & uint32(1<<n-1))
...@@ -413,18 +455,20 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { ...@@ -413,18 +455,20 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
} }
var dist int var dist int
if hd == nil { if f.hd == nil {
for f.nb < 5 { for f.nb < 5 {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
dist = int(reverseByte[(f.b&0x1F)<<3]) dist = int(reverseByte[(f.b&0x1F)<<3])
f.b >>= 5 f.b >>= 5
f.nb -= 5 f.nb -= 5
} else { } else {
if dist, err = f.huffSym(hd); err != nil { if dist, err = f.huffSym(f.hd); err != nil {
return err f.err = err
return
} }
} }
...@@ -432,14 +476,16 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { ...@@ -432,14 +476,16 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
case dist < 4: case dist < 4:
dist++ dist++
case dist >= 30: case dist >= 30:
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
default: default:
nb := uint(dist-2) >> 1 nb := uint(dist-2) >> 1
// have 1 bit in bottom of dist, need nb more. // have 1 bit in bottom of dist, need nb more.
extra := (dist & 1) << nb extra := (dist & 1) << nb
for f.nb < nb { for f.nb < nb {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
extra |= int(f.b & uint32(1<<nb-1)) extra |= int(f.b & uint32(1<<nb-1))
...@@ -450,12 +496,14 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { ...@@ -450,12 +496,14 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
// Copy history[-dist:-dist+length] into output. // Copy history[-dist:-dist+length] into output.
if dist > len(f.hist) { if dist > len(f.hist) {
return InternalError("bad history distance") f.err = InternalError("bad history distance")
return
} }
// No check on length; encoding can be prescient. // No check on length; encoding can be prescient.
if !f.hfull && dist > f.hp { if !f.hfull && dist > f.hp {
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
} }
p := f.hp - dist p := f.hp - dist
...@@ -467,9 +515,11 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { ...@@ -467,9 +515,11 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hp++ f.hp++
p++ p++
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { // After flush continue copying out of history.
return err f.copyLen = length - (i + 1)
} f.copyDist = dist
f.flush((*decompressor).copyHuff)
return
} }
if p == len(f.hist) { if p == len(f.hist) {
p = 0 p = 0
...@@ -479,8 +529,33 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { ...@@ -479,8 +529,33 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
panic("unreached") panic("unreached")
} }
func (f *decompressor) copyHuff() {
length := f.copyLen
dist := f.copyDist
p := f.hp - dist
if p < 0 {
p += len(f.hist)
}
for i := 0; i < length; i++ {
f.hist[f.hp] = f.hist[p]
f.hp++
p++
if f.hp == len(f.hist) {
f.copyLen = length - (i + 1)
f.flush((*decompressor).copyHuff)
return
}
if p == len(f.hist) {
p = 0
}
}
// Continue processing Huffman block.
f.huffmanBlock()
}
// Copy a single uncompressed data block from input to output. // Copy a single uncompressed data block from input to output.
func (f *decompressor) dataBlock() os.Error { func (f *decompressor) dataBlock() {
// Uncompressed. // Uncompressed.
// Discard current half-byte. // Discard current half-byte.
f.nb = 0 f.nb = 0
...@@ -490,21 +565,30 @@ func (f *decompressor) dataBlock() os.Error { ...@@ -490,21 +565,30 @@ func (f *decompressor) dataBlock() os.Error {
nr, err := io.ReadFull(f.r, f.buf[0:4]) nr, err := io.ReadFull(f.r, f.buf[0:4])
f.roffset += int64(nr) f.roffset += int64(nr)
if err != nil { if err != nil {
return &ReadError{f.roffset, err} f.err = &ReadError{f.roffset, err}
return
} }
n := int(f.buf[0]) | int(f.buf[1])<<8 n := int(f.buf[0]) | int(f.buf[1])<<8
nn := int(f.buf[2]) | int(f.buf[3])<<8 nn := int(f.buf[2]) | int(f.buf[3])<<8
if uint16(nn) != uint16(^n) { if uint16(nn) != uint16(^n) {
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
} }
if n == 0 { if n == 0 {
// 0-length block means sync // 0-length block means sync
return f.flush() f.flush((*decompressor).nextBlock)
return
} }
// Read len bytes into history, f.copyLen = n
// writing as history fills. f.copyData()
}
func (f *decompressor) copyData() {
// Read f.dataLen bytes into history,
// pausing for reads as history fills.
n := f.copyLen
for n > 0 { for n > 0 {
m := len(f.hist) - f.hp m := len(f.hist) - f.hp
if m > n { if m > n {
...@@ -513,17 +597,18 @@ func (f *decompressor) dataBlock() os.Error { ...@@ -513,17 +597,18 @@ func (f *decompressor) dataBlock() os.Error {
m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m]) m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m])
f.roffset += int64(m) f.roffset += int64(m)
if err != nil { if err != nil {
return &ReadError{f.roffset, err} f.err = &ReadError{f.roffset, err}
return
} }
n -= m n -= m
f.hp += m f.hp += m
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { f.copyLen = n
return err f.flush((*decompressor).copyData)
} return
} }
} }
return nil f.step = (*decompressor).nextBlock
} }
func (f *decompressor) setDict(dict []byte) { func (f *decompressor) setDict(dict []byte) {
...@@ -579,17 +664,8 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) { ...@@ -579,17 +664,8 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) {
} }
// Flush any buffered output to the underlying writer. // Flush any buffered output to the underlying writer.
func (f *decompressor) flush() os.Error { func (f *decompressor) flush(step func(*decompressor)) {
if f.hw == f.hp { f.toRead = f.hist[f.hw:f.hp]
return nil
}
n, err := f.w.Write(f.hist[f.hw:f.hp])
if n != f.hp-f.hw && err == nil {
err = io.ErrShortWrite
}
if err != nil {
return &WriteError{f.woffset, err}
}
f.woffset += int64(f.hp - f.hw) f.woffset += int64(f.hp - f.hw)
f.hw = f.hp f.hw = f.hp
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
...@@ -597,7 +673,7 @@ func (f *decompressor) flush() os.Error { ...@@ -597,7 +673,7 @@ func (f *decompressor) flush() os.Error {
f.hw = 0 f.hw = 0
f.hfull = true f.hfull = true
} }
return nil f.step = step
} }
func makeReader(r io.Reader) Reader { func makeReader(r io.Reader) Reader {
...@@ -607,30 +683,15 @@ func makeReader(r io.Reader) Reader { ...@@ -607,30 +683,15 @@ func makeReader(r io.Reader) Reader {
return bufio.NewReader(r) return bufio.NewReader(r)
} }
// decompress reads DEFLATE-compressed data from r and writes
// the uncompressed data to w.
func (f *decompressor) decompress(r io.Reader, w io.Writer) os.Error {
f.r = makeReader(r)
f.w = w
f.woffset = 0
if err := f.inflate(); err != nil {
return err
}
if err := f.flush(); err != nil {
return err
}
return nil
}
// NewReader returns a new ReadCloser that can be used // NewReader returns a new ReadCloser that can be used
// to read the uncompressed version of r. It is the caller's // to read the uncompressed version of r. It is the caller's
// responsibility to call Close on the ReadCloser when // responsibility to call Close on the ReadCloser when
// finished reading. // finished reading.
func NewReader(r io.Reader) io.ReadCloser { func NewReader(r io.Reader) io.ReadCloser {
var f decompressor var f decompressor
pr, pw := io.Pipe() f.r = makeReader(r)
go func() { pw.CloseWithError(f.decompress(r, pw)) }() f.step = (*decompressor).nextBlock
return pr return &f
} }
// NewReaderDict is like NewReader but initializes the reader // NewReaderDict is like NewReader but initializes the reader
...@@ -641,7 +702,7 @@ func NewReader(r io.Reader) io.ReadCloser { ...@@ -641,7 +702,7 @@ func NewReader(r io.Reader) io.ReadCloser {
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
var f decompressor var f decompressor
f.setDict(dict) f.setDict(dict)
pr, pw := io.Pipe() f.r = makeReader(r)
go func() { pw.CloseWithError(f.decompress(r, pw)) }() f.step = (*decompressor).nextBlock
return pr return &f
} }
...@@ -36,8 +36,8 @@ func makeReader(r io.Reader) flate.Reader { ...@@ -36,8 +36,8 @@ func makeReader(r io.Reader) flate.Reader {
return bufio.NewReader(r) return bufio.NewReader(r)
} }
var HeaderError os.Error = os.ErrorString("invalid gzip header") var HeaderError = os.NewError("invalid gzip header")
var ChecksumError os.Error = os.ErrorString("gzip checksum error") var ChecksumError = os.NewError("gzip checksum error")
// The gzip file stores a header giving metadata about the compressed file. // The gzip file stores a header giving metadata about the compressed file.
// That header is exposed as the fields of the Compressor and Decompressor structs. // That header is exposed as the fields of the Compressor and Decompressor structs.
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
) )
// pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the // pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the
// writer end and ifunc at the reader end. // writer end and cfunc at the reader end.
func pipe(t *testing.T, dfunc func(*Compressor), cfunc func(*Decompressor)) { func pipe(t *testing.T, dfunc func(*Compressor), cfunc func(*Decompressor)) {
piper, pipew := io.Pipe() piper, pipew := io.Pipe()
defer piper.Close() defer piper.Close()
......
...@@ -32,13 +32,49 @@ const ( ...@@ -32,13 +32,49 @@ const (
MSB MSB
) )
const (
maxWidth = 12
decoderInvalidCode = 0xffff
flushBuffer = 1 << maxWidth
)
// decoder is the state from which the readXxx method converts a byte // decoder is the state from which the readXxx method converts a byte
// stream into a code stream. // stream into a code stream.
type decoder struct { type decoder struct {
r io.ByteReader r io.ByteReader
bits uint32 bits uint32
nBits uint nBits uint
width uint width uint
read func(*decoder) (uint16, os.Error) // readLSB or readMSB
litWidth int // width in bits of literal codes
err os.Error
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
// overflow is the code at which hi overflows the code width.
// last is the most recently seen code, or decoderInvalidCode.
clear, eof, hi, overflow, last uint16
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// output is the temporary output buffer.
// Literal codes are accumulated from the start of the buffer.
// Non-literal codes decode to a sequence of suffixes that are first
// written right-to-left from the end of the buffer before being copied
// to the start of the buffer.
// It is flushed when it contains >= 1<<maxWidth bytes,
// so that there is always room to decode an entire code.
output [2 * 1 << maxWidth]byte
o int // write index into output
toRead []byte // bytes to return from Read
} }
// readLSB returns the next code for "Least Significant Bits first" data. // readLSB returns the next code for "Least Significant Bits first" data.
...@@ -73,119 +109,113 @@ func (d *decoder) readMSB() (uint16, os.Error) { ...@@ -73,119 +109,113 @@ func (d *decoder) readMSB() (uint16, os.Error) {
return code, nil return code, nil
} }
// decode decompresses bytes from r and writes them to pw. func (d *decoder) Read(b []byte) (int, os.Error) {
// read specifies how to decode bytes into codes. for {
// litWidth is the width in bits of literal codes. if len(d.toRead) > 0 {
func decode(r io.Reader, read func(*decoder) (uint16, os.Error), litWidth int, pw *io.PipeWriter) { n := copy(b, d.toRead)
br, ok := r.(io.ByteReader) d.toRead = d.toRead[n:]
if !ok { return n, nil
br = bufio.NewReader(r) }
if d.err != nil {
return 0, d.err
}
d.decode()
} }
pw.CloseWithError(decode1(pw, br, read, uint(litWidth))) panic("unreachable")
} }
func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os.Error), litWidth uint) os.Error { // decode decompresses bytes from r and leaves them in d.toRead.
const ( // read specifies how to decode bytes into codes.
maxWidth = 12 // litWidth is the width in bits of literal codes.
invalidCode = 0xffff func (d *decoder) decode() {
)
d := decoder{r, 0, 0, 1 + litWidth}
w := bufio.NewWriter(pw)
// The first 1<<litWidth codes are literal codes.
// The next two codes mean clear and EOF.
// Other valid codes are in the range [lo, hi] where lo := clear + 2,
// with the upper bound incrementing on each code seen.
clear := uint16(1) << litWidth
eof, hi := clear+1, clear+1
// overflow is the code at which hi overflows the code width.
overflow := uint16(1) << d.width
var (
// Each code c in [lo, hi] expands to two or more bytes. For c != hi:
// suffix[c] is the last of these bytes.
// prefix[c] is the code for all but the last byte.
// This code can either be a literal code or another code in [lo, c).
// The c == hi case is a special case.
suffix [1 << maxWidth]uint8
prefix [1 << maxWidth]uint16
// buf is a scratch buffer for reconstituting the bytes that a code expands to.
// Code suffixes are written right-to-left from the end of the buffer.
buf [1 << maxWidth]byte
)
// Loop over the code stream, converting codes into decompressed bytes. // Loop over the code stream, converting codes into decompressed bytes.
last := uint16(invalidCode)
for { for {
code, err := read(&d) code, err := d.read(d)
if err != nil { if err != nil {
if err == os.EOF { if err == os.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
return err d.err = err
return
} }
switch { switch {
case code < clear: case code < d.clear:
// We have a literal code. // We have a literal code.
if err := w.WriteByte(uint8(code)); err != nil { d.output[d.o] = uint8(code)
return err d.o++
} if d.last != decoderInvalidCode {
if last != invalidCode {
// Save what the hi code expands to. // Save what the hi code expands to.
suffix[hi] = uint8(code) d.suffix[d.hi] = uint8(code)
prefix[hi] = last d.prefix[d.hi] = d.last
} }
case code == clear: case code == d.clear:
d.width = 1 + litWidth d.width = 1 + uint(d.litWidth)
hi = eof d.hi = d.eof
overflow = 1 << d.width d.overflow = 1 << d.width
last = invalidCode d.last = decoderInvalidCode
continue continue
case code == eof: case code == d.eof:
return w.Flush() d.flush()
case code <= hi: d.err = os.EOF
c, i := code, len(buf)-1 return
if code == hi { case code <= d.hi:
c, i := code, len(d.output)-1
if code == d.hi {
// code == hi is a special case which expands to the last expansion // code == hi is a special case which expands to the last expansion
// followed by the head of the last expansion. To find the head, we walk // followed by the head of the last expansion. To find the head, we walk
// the prefix chain until we find a literal code. // the prefix chain until we find a literal code.
c = last c = d.last
for c >= clear { for c >= d.clear {
c = prefix[c] c = d.prefix[c]
} }
buf[i] = uint8(c) d.output[i] = uint8(c)
i-- i--
c = last c = d.last
} }
// Copy the suffix chain into buf and then write that to w. // Copy the suffix chain into output and then write that to w.
for c >= clear { for c >= d.clear {
buf[i] = suffix[c] d.output[i] = d.suffix[c]
i-- i--
c = prefix[c] c = d.prefix[c]
} }
buf[i] = uint8(c) d.output[i] = uint8(c)
if _, err := w.Write(buf[i:]); err != nil { d.o += copy(d.output[d.o:], d.output[i:])
return err if d.last != decoderInvalidCode {
}
if last != invalidCode {
// Save what the hi code expands to. // Save what the hi code expands to.
suffix[hi] = uint8(c) d.suffix[d.hi] = uint8(c)
prefix[hi] = last d.prefix[d.hi] = d.last
} }
default: default:
return os.NewError("lzw: invalid code") d.err = os.NewError("lzw: invalid code")
return
} }
last, hi = code, hi+1 d.last, d.hi = code, d.hi+1
if hi >= overflow { if d.hi >= d.overflow {
if d.width == maxWidth { if d.width == maxWidth {
last = invalidCode d.last = decoderInvalidCode
continue } else {
d.width++
d.overflow <<= 1
} }
d.width++ }
overflow <<= 1 if d.o >= flushBuffer {
d.flush()
return
} }
} }
panic("unreachable") panic("unreachable")
} }
func (d *decoder) flush() {
d.toRead = d.output[:d.o]
d.o = 0
}
func (d *decoder) Close() os.Error {
d.err = os.EINVAL // in case any Reads come along
return nil
}
// NewReader creates a new io.ReadCloser that satisfies reads by decompressing // NewReader creates a new io.ReadCloser that satisfies reads by decompressing
// the data read from r. // the data read from r.
// It is the caller's responsibility to call Close on the ReadCloser when // It is the caller's responsibility to call Close on the ReadCloser when
...@@ -193,21 +223,31 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os ...@@ -193,21 +223,31 @@ func decode1(pw *io.PipeWriter, r io.ByteReader, read func(*decoder) (uint16, os
// The number of bits to use for literal codes, litWidth, must be in the // The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. // range [2,8] and is typically 8.
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser { func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
pr, pw := io.Pipe() d := new(decoder)
var read func(*decoder) (uint16, os.Error)
switch order { switch order {
case LSB: case LSB:
read = (*decoder).readLSB d.read = (*decoder).readLSB
case MSB: case MSB:
read = (*decoder).readMSB d.read = (*decoder).readMSB
default: default:
pw.CloseWithError(os.NewError("lzw: unknown order")) d.err = os.NewError("lzw: unknown order")
return pr return d
} }
if litWidth < 2 || 8 < litWidth { if litWidth < 2 || 8 < litWidth {
pw.CloseWithError(fmt.Errorf("lzw: litWidth %d out of range", litWidth)) d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
return pr return d
} }
go decode(r, read, litWidth, pw) if br, ok := r.(io.ByteReader); ok {
return pr d.r = br
} else {
d.r = bufio.NewReader(r)
}
d.litWidth = litWidth
d.width = 1 + uint(litWidth)
d.clear = uint16(1) << uint(litWidth)
d.eof, d.hi = d.clear+1, d.clear+1
d.overflow = uint16(1) << d.width
d.last = decoderInvalidCode
return d
} }
...@@ -84,7 +84,7 @@ var lzwTests = []lzwTest{ ...@@ -84,7 +84,7 @@ var lzwTests = []lzwTest{
func TestReader(t *testing.T) { func TestReader(t *testing.T) {
b := bytes.NewBuffer(nil) b := bytes.NewBuffer(nil)
for _, tt := range lzwTests { for _, tt := range lzwTests {
d := strings.Split(tt.desc, ";", -1) d := strings.Split(tt.desc, ";")
var order Order var order Order
switch d[1] { switch d[1] {
case "LSB": case "LSB":
......
...@@ -77,13 +77,13 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) { ...@@ -77,13 +77,13 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1) t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
return return
} }
if len(b0) != len(b1) { if len(b1) != len(b0) {
t.Errorf("%s (order=%d litWidth=%d): length mismatch %d versus %d", fn, order, litWidth, len(b0), len(b1)) t.Errorf("%s (order=%d litWidth=%d): length mismatch %d != %d", fn, order, litWidth, len(b1), len(b0))
return return
} }
for i := 0; i < len(b0); i++ { for i := 0; i < len(b0); i++ {
if b0[i] != b1[i] { if b1[i] != b0[i] {
t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x versus 0x%02x\n", fn, order, litWidth, i, b0[i], b1[i]) t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x != 0x%02x\n", fn, order, litWidth, i, b1[i], b0[i])
return return
} }
} }
......
...@@ -34,9 +34,9 @@ import ( ...@@ -34,9 +34,9 @@ import (
const zlibDeflate = 8 const zlibDeflate = 8
var ChecksumError os.Error = os.ErrorString("zlib checksum error") var ChecksumError = os.NewError("zlib checksum error")
var HeaderError os.Error = os.ErrorString("invalid zlib header") var HeaderError = os.NewError("invalid zlib header")
var DictionaryError os.Error = os.ErrorString("invalid zlib dictionary") var DictionaryError = os.NewError("invalid zlib dictionary")
type reader struct { type reader struct {
r flate.Reader r flate.Reader
......
...@@ -89,7 +89,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) { ...@@ -89,7 +89,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, os.Error) {
} }
} }
z.w = w z.w = w
z.compressor = flate.NewWriter(w, level) z.compressor = flate.NewWriterDict(w, level, dict)
z.digest = adler32.New() z.digest = adler32.New()
return z, nil return z, nil
} }
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
package zlib package zlib
import ( import (
"bytes"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
...@@ -16,15 +18,13 @@ var filenames = []string{ ...@@ -16,15 +18,13 @@ var filenames = []string{
"../testdata/pi.txt", "../testdata/pi.txt",
} }
var data = []string{
"test a reasonable sized string that can be compressed",
}
// Tests that compressing and then decompressing the given file at the given compression level and dictionary // Tests that compressing and then decompressing the given file at the given compression level and dictionary
// yields equivalent bytes to the original file. // yields equivalent bytes to the original file.
func testFileLevelDict(t *testing.T, fn string, level int, d string) { func testFileLevelDict(t *testing.T, fn string, level int, d string) {
// Read dictionary, if given.
var dict []byte
if d != "" {
dict = []byte(d)
}
// Read the file, as golden output. // Read the file, as golden output.
golden, err := os.Open(fn) golden, err := os.Open(fn)
if err != nil { if err != nil {
...@@ -32,17 +32,25 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) { ...@@ -32,17 +32,25 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
return return
} }
defer golden.Close() defer golden.Close()
b0, err0 := ioutil.ReadAll(golden)
// Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end. if err0 != nil {
raw, err := os.Open(fn) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
} }
testLevelDict(t, fn, b0, level, d)
}
func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
// Make dictionary, if given.
var dict []byte
if d != "" {
dict = []byte(d)
}
// Push data through a pipe that compresses at the write end, and decompresses at the read end.
piper, pipew := io.Pipe() piper, pipew := io.Pipe()
defer piper.Close() defer piper.Close()
go func() { go func() {
defer raw.Close()
defer pipew.Close() defer pipew.Close()
zlibw, err := NewWriterDict(pipew, level, dict) zlibw, err := NewWriterDict(pipew, level, dict)
if err != nil { if err != nil {
...@@ -50,25 +58,14 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) { ...@@ -50,25 +58,14 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
return return
} }
defer zlibw.Close() defer zlibw.Close()
var b [1024]byte _, err = zlibw.Write(b0)
for { if err == os.EPIPE {
n, err0 := raw.Read(b[0:]) // Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
if err0 != nil && err0 != os.EOF { return
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0) }
return if err != nil {
} t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
_, err1 := zlibw.Write(b[0:n]) return
if err1 == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
return
}
if err1 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return
}
if err0 == os.EOF {
break
}
} }
}() }()
zlibr, err := NewReaderDict(piper, dict) zlibr, err := NewReaderDict(piper, dict)
...@@ -78,13 +75,8 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) { ...@@ -78,13 +75,8 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
} }
defer zlibr.Close() defer zlibr.Close()
// Compare the two. // Compare the decompressed data.
b0, err0 := ioutil.ReadAll(golden)
b1, err1 := ioutil.ReadAll(zlibr) b1, err1 := ioutil.ReadAll(zlibr)
if err0 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err0)
return
}
if err1 != nil { if err1 != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err1)
return return
...@@ -102,6 +94,18 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) { ...@@ -102,6 +94,18 @@ func testFileLevelDict(t *testing.T, fn string, level int, d string) {
} }
func TestWriter(t *testing.T) { func TestWriter(t *testing.T) {
for i, s := range data {
b := []byte(s)
tag := fmt.Sprintf("#%d", i)
testLevelDict(t, tag, b, DefaultCompression, "")
testLevelDict(t, tag, b, NoCompression, "")
for level := BestSpeed; level <= BestCompression; level++ {
testLevelDict(t, tag, b, level, "")
}
}
}
func TestWriterBig(t *testing.T) {
for _, fn := range filenames { for _, fn := range filenames {
testFileLevelDict(t, fn, DefaultCompression, "") testFileLevelDict(t, fn, DefaultCompression, "")
testFileLevelDict(t, fn, NoCompression, "") testFileLevelDict(t, fn, NoCompression, "")
...@@ -121,3 +125,20 @@ func TestWriterDict(t *testing.T) { ...@@ -121,3 +125,20 @@ func TestWriterDict(t *testing.T) {
} }
} }
} }
func TestWriterDictIsUsed(t *testing.T) {
var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
buf := bytes.NewBuffer(nil)
compressor, err := NewWriterDict(buf, BestCompression, input)
if err != nil {
t.Errorf("error in NewWriterDict: %s", err)
return
}
compressor.Write(input)
compressor.Close()
const expectedMaxSize = 25
output := buf.Bytes()
if len(output) > expectedMaxSize {
t.Errorf("result too large (got %d, want <= %d bytes). Is the dictionary being used?", len(output), expectedMaxSize)
}
}
...@@ -21,8 +21,7 @@ type Interface interface { ...@@ -21,8 +21,7 @@ type Interface interface {
Pop() interface{} Pop() interface{}
} }
// A heap must be initialized before any of the heap operations
// A heaper must be initialized before any of the heap operations
// can be used. Init is idempotent with respect to the heap invariants // can be used. Init is idempotent with respect to the heap invariants
// and may be called whenever the heap invariants may have been invalidated. // and may be called whenever the heap invariants may have been invalidated.
// Its complexity is O(n) where n = h.Len(). // Its complexity is O(n) where n = h.Len().
...@@ -35,7 +34,6 @@ func Init(h Interface) { ...@@ -35,7 +34,6 @@ func Init(h Interface) {
} }
} }
// Push pushes the element x onto the heap. The complexity is // Push pushes the element x onto the heap. The complexity is
// O(log(n)) where n = h.Len(). // O(log(n)) where n = h.Len().
// //
...@@ -44,7 +42,6 @@ func Push(h Interface, x interface{}) { ...@@ -44,7 +42,6 @@ func Push(h Interface, x interface{}) {
up(h, h.Len()-1) up(h, h.Len()-1)
} }
// Pop removes the minimum element (according to Less) from the heap // Pop removes the minimum element (according to Less) from the heap
// and returns it. The complexity is O(log(n)) where n = h.Len(). // and returns it. The complexity is O(log(n)) where n = h.Len().
// Same as Remove(h, 0). // Same as Remove(h, 0).
...@@ -56,7 +53,6 @@ func Pop(h Interface) interface{} { ...@@ -56,7 +53,6 @@ func Pop(h Interface) interface{} {
return h.Pop() return h.Pop()
} }
// Remove removes the element at index i from the heap. // Remove removes the element at index i from the heap.
// The complexity is O(log(n)) where n = h.Len(). // The complexity is O(log(n)) where n = h.Len().
// //
...@@ -70,7 +66,6 @@ func Remove(h Interface, i int) interface{} { ...@@ -70,7 +66,6 @@ func Remove(h Interface, i int) interface{} {
return h.Pop() return h.Pop()
} }
func up(h Interface, j int) { func up(h Interface, j int) {
for { for {
i := (j - 1) / 2 // parent i := (j - 1) / 2 // parent
...@@ -82,7 +77,6 @@ func up(h Interface, j int) { ...@@ -82,7 +77,6 @@ func up(h Interface, j int) {
} }
} }
func down(h Interface, i, n int) { func down(h Interface, i, n int) {
for { for {
j1 := 2*i + 1 j1 := 2*i + 1
......
...@@ -10,17 +10,14 @@ import ( ...@@ -10,17 +10,14 @@ import (
. "container/heap" . "container/heap"
) )
type myHeap struct { type myHeap struct {
// A vector.Vector implements sort.Interface except for Less, // A vector.Vector implements sort.Interface except for Less,
// and it implements Push and Pop as required for heap.Interface. // and it implements Push and Pop as required for heap.Interface.
vector.Vector vector.Vector
} }
func (h *myHeap) Less(i, j int) bool { return h.At(i).(int) < h.At(j).(int) } func (h *myHeap) Less(i, j int) bool { return h.At(i).(int) < h.At(j).(int) }
func (h *myHeap) verify(t *testing.T, i int) { func (h *myHeap) verify(t *testing.T, i int) {
n := h.Len() n := h.Len()
j1 := 2*i + 1 j1 := 2*i + 1
...@@ -41,7 +38,6 @@ func (h *myHeap) verify(t *testing.T, i int) { ...@@ -41,7 +38,6 @@ func (h *myHeap) verify(t *testing.T, i int) {
} }
} }
func TestInit0(t *testing.T) { func TestInit0(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 20; i > 0; i-- { for i := 20; i > 0; i-- {
...@@ -59,7 +55,6 @@ func TestInit0(t *testing.T) { ...@@ -59,7 +55,6 @@ func TestInit0(t *testing.T) {
} }
} }
func TestInit1(t *testing.T) { func TestInit1(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 20; i > 0; i-- { for i := 20; i > 0; i-- {
...@@ -77,7 +72,6 @@ func TestInit1(t *testing.T) { ...@@ -77,7 +72,6 @@ func TestInit1(t *testing.T) {
} }
} }
func Test(t *testing.T) { func Test(t *testing.T) {
h := new(myHeap) h := new(myHeap)
h.verify(t, 0) h.verify(t, 0)
...@@ -105,7 +99,6 @@ func Test(t *testing.T) { ...@@ -105,7 +99,6 @@ func Test(t *testing.T) {
} }
} }
func TestRemove0(t *testing.T) { func TestRemove0(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
...@@ -123,7 +116,6 @@ func TestRemove0(t *testing.T) { ...@@ -123,7 +116,6 @@ func TestRemove0(t *testing.T) {
} }
} }
func TestRemove1(t *testing.T) { func TestRemove1(t *testing.T) {
h := new(myHeap) h := new(myHeap)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
...@@ -140,7 +132,6 @@ func TestRemove1(t *testing.T) { ...@@ -140,7 +132,6 @@ func TestRemove1(t *testing.T) {
} }
} }
func TestRemove2(t *testing.T) { func TestRemove2(t *testing.T) {
N := 10 N := 10
......
...@@ -16,14 +16,12 @@ type Ring struct { ...@@ -16,14 +16,12 @@ type Ring struct {
Value interface{} // for use by client; untouched by this library Value interface{} // for use by client; untouched by this library
} }
func (r *Ring) init() *Ring { func (r *Ring) init() *Ring {
r.next = r r.next = r
r.prev = r r.prev = r
return r return r
} }
// Next returns the next ring element. r must not be empty. // Next returns the next ring element. r must not be empty.
func (r *Ring) Next() *Ring { func (r *Ring) Next() *Ring {
if r.next == nil { if r.next == nil {
...@@ -32,7 +30,6 @@ func (r *Ring) Next() *Ring { ...@@ -32,7 +30,6 @@ func (r *Ring) Next() *Ring {
return r.next return r.next
} }
// Prev returns the previous ring element. r must not be empty. // Prev returns the previous ring element. r must not be empty.
func (r *Ring) Prev() *Ring { func (r *Ring) Prev() *Ring {
if r.next == nil { if r.next == nil {
...@@ -41,7 +38,6 @@ func (r *Ring) Prev() *Ring { ...@@ -41,7 +38,6 @@ func (r *Ring) Prev() *Ring {
return r.prev return r.prev
} }
// Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0) // Move moves n % r.Len() elements backward (n < 0) or forward (n >= 0)
// in the ring and returns that ring element. r must not be empty. // in the ring and returns that ring element. r must not be empty.
// //
...@@ -62,7 +58,6 @@ func (r *Ring) Move(n int) *Ring { ...@@ -62,7 +58,6 @@ func (r *Ring) Move(n int) *Ring {
return r return r
} }
// New creates a ring of n elements. // New creates a ring of n elements.
func New(n int) *Ring { func New(n int) *Ring {
if n <= 0 { if n <= 0 {
...@@ -79,7 +74,6 @@ func New(n int) *Ring { ...@@ -79,7 +74,6 @@ func New(n int) *Ring {
return r return r
} }
// Link connects ring r with with ring s such that r.Next() // Link connects ring r with with ring s such that r.Next()
// becomes s and returns the original value for r.Next(). // becomes s and returns the original value for r.Next().
// r must not be empty. // r must not be empty.
...@@ -110,7 +104,6 @@ func (r *Ring) Link(s *Ring) *Ring { ...@@ -110,7 +104,6 @@ func (r *Ring) Link(s *Ring) *Ring {
return n return n
} }
// Unlink removes n % r.Len() elements from the ring r, starting // Unlink removes n % r.Len() elements from the ring r, starting
// at r.Next(). If n % r.Len() == 0, r remains unchanged. // at r.Next(). If n % r.Len() == 0, r remains unchanged.
// The result is the removed subring. r must not be empty. // The result is the removed subring. r must not be empty.
...@@ -122,7 +115,6 @@ func (r *Ring) Unlink(n int) *Ring { ...@@ -122,7 +115,6 @@ func (r *Ring) Unlink(n int) *Ring {
return r.Link(r.Move(n + 1)) return r.Link(r.Move(n + 1))
} }
// Len computes the number of elements in ring r. // Len computes the number of elements in ring r.
// It executes in time proportional to the number of elements. // It executes in time proportional to the number of elements.
// //
...@@ -137,7 +129,6 @@ func (r *Ring) Len() int { ...@@ -137,7 +129,6 @@ func (r *Ring) Len() int {
return n return n
} }
// Do calls function f on each element of the ring, in forward order. // Do calls function f on each element of the ring, in forward order.
// The behavior of Do is undefined if f changes *r. // The behavior of Do is undefined if f changes *r.
func (r *Ring) Do(f func(interface{})) { func (r *Ring) Do(f func(interface{})) {
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"testing" "testing"
) )
// For debugging - keep around. // For debugging - keep around.
func dump(r *Ring) { func dump(r *Ring) {
if r == nil { if r == nil {
...@@ -24,7 +23,6 @@ func dump(r *Ring) { ...@@ -24,7 +23,6 @@ func dump(r *Ring) {
fmt.Println() fmt.Println()
} }
func verify(t *testing.T, r *Ring, N int, sum int) { func verify(t *testing.T, r *Ring, N int, sum int) {
// Len // Len
n := r.Len() n := r.Len()
...@@ -96,7 +94,6 @@ func verify(t *testing.T, r *Ring, N int, sum int) { ...@@ -96,7 +94,6 @@ func verify(t *testing.T, r *Ring, N int, sum int) {
} }
} }
func TestCornerCases(t *testing.T) { func TestCornerCases(t *testing.T) {
var ( var (
r0 *Ring r0 *Ring
...@@ -118,7 +115,6 @@ func TestCornerCases(t *testing.T) { ...@@ -118,7 +115,6 @@ func TestCornerCases(t *testing.T) {
verify(t, &r1, 1, 0) verify(t, &r1, 1, 0)
} }
func makeN(n int) *Ring { func makeN(n int) *Ring {
r := New(n) r := New(n)
for i := 1; i <= n; i++ { for i := 1; i <= n; i++ {
...@@ -130,7 +126,6 @@ func makeN(n int) *Ring { ...@@ -130,7 +126,6 @@ func makeN(n int) *Ring {
func sumN(n int) int { return (n*n + n) / 2 } func sumN(n int) int { return (n*n + n) / 2 }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
r := New(i) r := New(i)
...@@ -142,7 +137,6 @@ func TestNew(t *testing.T) { ...@@ -142,7 +137,6 @@ func TestNew(t *testing.T) {
} }
} }
func TestLink1(t *testing.T) { func TestLink1(t *testing.T) {
r1a := makeN(1) r1a := makeN(1)
var r1b Ring var r1b Ring
...@@ -163,7 +157,6 @@ func TestLink1(t *testing.T) { ...@@ -163,7 +157,6 @@ func TestLink1(t *testing.T) {
verify(t, r2b, 1, 0) verify(t, r2b, 1, 0)
} }
func TestLink2(t *testing.T) { func TestLink2(t *testing.T) {
var r0 *Ring var r0 *Ring
r1a := &Ring{Value: 42} r1a := &Ring{Value: 42}
...@@ -183,7 +176,6 @@ func TestLink2(t *testing.T) { ...@@ -183,7 +176,6 @@ func TestLink2(t *testing.T) {
verify(t, r10, 12, sumN(10)+42+77) verify(t, r10, 12, sumN(10)+42+77)
} }
func TestLink3(t *testing.T) { func TestLink3(t *testing.T) {
var r Ring var r Ring
n := 1 n := 1
...@@ -193,7 +185,6 @@ func TestLink3(t *testing.T) { ...@@ -193,7 +185,6 @@ func TestLink3(t *testing.T) {
} }
} }
func TestUnlink(t *testing.T) { func TestUnlink(t *testing.T) {
r10 := makeN(10) r10 := makeN(10)
s10 := r10.Move(6) s10 := r10.Move(6)
...@@ -215,7 +206,6 @@ func TestUnlink(t *testing.T) { ...@@ -215,7 +206,6 @@ func TestUnlink(t *testing.T) {
verify(t, r10, 9, sum10-2) verify(t, r10, 9, sum10-2)
} }
func TestLinkUnlink(t *testing.T) { func TestLinkUnlink(t *testing.T) {
for i := 1; i < 4; i++ { for i := 1; i < 4; i++ {
ri := New(i) ri := New(i)
......
...@@ -6,29 +6,24 @@ ...@@ -6,29 +6,24 @@
// Vectors grow and shrink dynamically as necessary. // Vectors grow and shrink dynamically as necessary.
package vector package vector
// Vector is a container for numbered sequences of elements of type interface{}. // Vector is a container for numbered sequences of elements of type interface{}.
// A vector's length and capacity adjusts automatically as necessary. // A vector's length and capacity adjusts automatically as necessary.
// The zero value for Vector is an empty vector ready to use. // The zero value for Vector is an empty vector ready to use.
type Vector []interface{} type Vector []interface{}
// IntVector is a container for numbered sequences of elements of type int. // IntVector is a container for numbered sequences of elements of type int.
// A vector's length and capacity adjusts automatically as necessary. // A vector's length and capacity adjusts automatically as necessary.
// The zero value for IntVector is an empty vector ready to use. // The zero value for IntVector is an empty vector ready to use.
type IntVector []int type IntVector []int
// StringVector is a container for numbered sequences of elements of type string. // StringVector is a container for numbered sequences of elements of type string.
// A vector's length and capacity adjusts automatically as necessary. // A vector's length and capacity adjusts automatically as necessary.
// The zero value for StringVector is an empty vector ready to use. // The zero value for StringVector is an empty vector ready to use.
type StringVector []string type StringVector []string
// Initial underlying array size // Initial underlying array size
const initialSize = 8 const initialSize = 8
// Partial sort.Interface support // Partial sort.Interface support
// LessInterface provides partial support of the sort.Interface. // LessInterface provides partial support of the sort.Interface.
...@@ -36,16 +31,13 @@ type LessInterface interface { ...@@ -36,16 +31,13 @@ type LessInterface interface {
Less(y interface{}) bool Less(y interface{}) bool
} }
// Less returns a boolean denoting whether the i'th element is less than the j'th element. // Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *Vector) Less(i, j int) bool { return (*p)[i].(LessInterface).Less((*p)[j]) } func (p *Vector) Less(i, j int) bool { return (*p)[i].(LessInterface).Less((*p)[j]) }
// sort.Interface support // sort.Interface support
// Less returns a boolean denoting whether the i'th element is less than the j'th element. // Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *IntVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] } func (p *IntVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }
// Less returns a boolean denoting whether the i'th element is less than the j'th element. // Less returns a boolean denoting whether the i'th element is less than the j'th element.
func (p *StringVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] } func (p *StringVector) Less(i, j int) bool { return (*p)[i] < (*p)[j] }
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
package vector package vector
func (p *IntVector) realloc(length, capacity int) (b []int) { func (p *IntVector) realloc(length, capacity int) (b []int) {
if capacity < initialSize { if capacity < initialSize {
capacity = initialSize capacity = initialSize
...@@ -21,7 +20,6 @@ func (p *IntVector) realloc(length, capacity int) (b []int) { ...@@ -21,7 +20,6 @@ func (p *IntVector) realloc(length, capacity int) (b []int) {
return return
} }
// Insert n elements at position i. // Insert n elements at position i.
func (p *IntVector) Expand(i, n int) { func (p *IntVector) Expand(i, n int) {
a := *p a := *p
...@@ -51,11 +49,9 @@ func (p *IntVector) Expand(i, n int) { ...@@ -51,11 +49,9 @@ func (p *IntVector) Expand(i, n int) {
*p = a *p = a
} }
// Insert n elements at the end of a vector. // Insert n elements at the end of a vector.
func (p *IntVector) Extend(n int) { p.Expand(len(*p), n) } func (p *IntVector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector. // Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards // If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length, // trailing elements. If the new length is longer than the current length,
...@@ -80,30 +76,24 @@ func (p *IntVector) Resize(length, capacity int) *IntVector { ...@@ -80,30 +76,24 @@ func (p *IntVector) Resize(length, capacity int) *IntVector {
return p return p
} }
// Len returns the number of elements in the vector. // Len returns the number of elements in the vector.
// Same as len(*p). // Same as len(*p).
func (p *IntVector) Len() int { return len(*p) } func (p *IntVector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the // Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing. // maximum length the vector can grow without resizing.
// Same as cap(*p). // Same as cap(*p).
func (p *IntVector) Cap() int { return cap(*p) } func (p *IntVector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector. // At returns the i'th element of the vector.
func (p *IntVector) At(i int) int { return (*p)[i] } func (p *IntVector) At(i int) int { return (*p)[i] }
// Set sets the i'th element of the vector to value x. // Set sets the i'th element of the vector to value x.
func (p *IntVector) Set(i int, x int) { (*p)[i] = x } func (p *IntVector) Set(i int, x int) { (*p)[i] = x }
// Last returns the element in the vector of highest index. // Last returns the element in the vector of highest index.
func (p *IntVector) Last() int { return (*p)[len(*p)-1] } func (p *IntVector) Last() int { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it. // Copy makes a copy of the vector and returns it.
func (p *IntVector) Copy() IntVector { func (p *IntVector) Copy() IntVector {
arr := make(IntVector, len(*p)) arr := make(IntVector, len(*p))
...@@ -111,7 +101,6 @@ func (p *IntVector) Copy() IntVector { ...@@ -111,7 +101,6 @@ func (p *IntVector) Copy() IntVector {
return arr return arr
} }
// Insert inserts into the vector an element of value x before // Insert inserts into the vector an element of value x before
// the current element at index i. // the current element at index i.
func (p *IntVector) Insert(i int, x int) { func (p *IntVector) Insert(i int, x int) {
...@@ -119,7 +108,6 @@ func (p *IntVector) Insert(i int, x int) { ...@@ -119,7 +108,6 @@ func (p *IntVector) Insert(i int, x int) {
(*p)[i] = x (*p)[i] = x
} }
// Delete deletes the i'th element of the vector. The gap is closed so the old // Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards. // element at index i+1 has index i afterwards.
func (p *IntVector) Delete(i int) { func (p *IntVector) Delete(i int) {
...@@ -132,7 +120,6 @@ func (p *IntVector) Delete(i int) { ...@@ -132,7 +120,6 @@ func (p *IntVector) Delete(i int) {
*p = a[0 : n-1] *p = a[0 : n-1]
} }
// InsertVector inserts into the vector the contents of the vector // InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion. // x such that the 0th element of x appears at index i after insertion.
func (p *IntVector) InsertVector(i int, x *IntVector) { func (p *IntVector) InsertVector(i int, x *IntVector) {
...@@ -142,7 +129,6 @@ func (p *IntVector) InsertVector(i int, x *IntVector) { ...@@ -142,7 +129,6 @@ func (p *IntVector) InsertVector(i int, x *IntVector) {
copy((*p)[i:i+len(b)], b) copy((*p)[i:i+len(b)], b)
} }
// Cut deletes elements i through j-1, inclusive. // Cut deletes elements i through j-1, inclusive.
func (p *IntVector) Cut(i, j int) { func (p *IntVector) Cut(i, j int) {
a := *p a := *p
...@@ -158,7 +144,6 @@ func (p *IntVector) Cut(i, j int) { ...@@ -158,7 +144,6 @@ func (p *IntVector) Cut(i, j int) {
*p = a[0:m] *p = a[0:m]
} }
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j]. // Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged. // The elements are copied. The original vector is unchanged.
func (p *IntVector) Slice(i, j int) *IntVector { func (p *IntVector) Slice(i, j int) *IntVector {
...@@ -168,13 +153,11 @@ func (p *IntVector) Slice(i, j int) *IntVector { ...@@ -168,13 +153,11 @@ func (p *IntVector) Slice(i, j int) *IntVector {
return &s return &s
} }
// Convenience wrappers // Convenience wrappers
// Push appends x to the end of the vector. // Push appends x to the end of the vector.
func (p *IntVector) Push(x int) { p.Insert(len(*p), x) } func (p *IntVector) Push(x int) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector. // Pop deletes the last element of the vector.
func (p *IntVector) Pop() int { func (p *IntVector) Pop() int {
a := *p a := *p
...@@ -187,18 +170,15 @@ func (p *IntVector) Pop() int { ...@@ -187,18 +170,15 @@ func (p *IntVector) Pop() int {
return x return x
} }
// AppendVector appends the entire vector x to the end of this vector. // AppendVector appends the entire vector x to the end of this vector.
func (p *IntVector) AppendVector(x *IntVector) { p.InsertVector(len(*p), x) } func (p *IntVector) AppendVector(x *IntVector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j. // Swap exchanges the elements at indexes i and j.
func (p *IntVector) Swap(i, j int) { func (p *IntVector) Swap(i, j int) {
a := *p a := *p
a[i], a[j] = a[j], a[i] a[i], a[j] = a[j], a[i]
} }
// Do calls function f for each element of the vector, in order. // Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p. // The behavior of Do is undefined if f changes *p.
func (p *IntVector) Do(f func(elem int)) { func (p *IntVector) Do(f func(elem int)) {
......
...@@ -9,7 +9,6 @@ package vector ...@@ -9,7 +9,6 @@ package vector
import "testing" import "testing"
func TestIntZeroLen(t *testing.T) { func TestIntZeroLen(t *testing.T) {
a := new(IntVector) a := new(IntVector)
if a.Len() != 0 { if a.Len() != 0 {
...@@ -27,7 +26,6 @@ func TestIntZeroLen(t *testing.T) { ...@@ -27,7 +26,6 @@ func TestIntZeroLen(t *testing.T) {
} }
} }
func TestIntResize(t *testing.T) { func TestIntResize(t *testing.T) {
var a IntVector var a IntVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
...@@ -40,7 +38,6 @@ func TestIntResize(t *testing.T) { ...@@ -40,7 +38,6 @@ func TestIntResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100) checkSize(t, a.Resize(11, 100), 11, 100)
} }
func TestIntResize2(t *testing.T) { func TestIntResize2(t *testing.T) {
var a IntVector var a IntVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
...@@ -62,7 +59,6 @@ func TestIntResize2(t *testing.T) { ...@@ -62,7 +59,6 @@ func TestIntResize2(t *testing.T) {
} }
} }
func checkIntZero(t *testing.T, a *IntVector, i int) { func checkIntZero(t *testing.T, a *IntVector, i int) {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if a.At(j) == intzero { if a.At(j) == intzero {
...@@ -82,7 +78,6 @@ func checkIntZero(t *testing.T, a *IntVector, i int) { ...@@ -82,7 +78,6 @@ func checkIntZero(t *testing.T, a *IntVector, i int) {
} }
} }
func TestIntTrailingElements(t *testing.T) { func TestIntTrailingElements(t *testing.T) {
var a IntVector var a IntVector
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
...@@ -95,7 +90,6 @@ func TestIntTrailingElements(t *testing.T) { ...@@ -95,7 +90,6 @@ func TestIntTrailingElements(t *testing.T) {
checkIntZero(t, &a, 5) checkIntZero(t, &a, 5)
} }
func TestIntAccess(t *testing.T) { func TestIntAccess(t *testing.T) {
const n = 100 const n = 100
var a IntVector var a IntVector
...@@ -120,7 +114,6 @@ func TestIntAccess(t *testing.T) { ...@@ -120,7 +114,6 @@ func TestIntAccess(t *testing.T) {
} }
} }
func TestIntInsertDeleteClear(t *testing.T) { func TestIntInsertDeleteClear(t *testing.T) {
const n = 100 const n = 100
var a IntVector var a IntVector
...@@ -207,7 +200,6 @@ func TestIntInsertDeleteClear(t *testing.T) { ...@@ -207,7 +200,6 @@ func TestIntInsertDeleteClear(t *testing.T) {
} }
} }
func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) { func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
for k := i; k < j; k++ { for k := i; k < j; k++ {
if elem2IntValue(x.At(k)) != int2IntValue(elt) { if elem2IntValue(x.At(k)) != int2IntValue(elt) {
...@@ -223,7 +215,6 @@ func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) { ...@@ -223,7 +215,6 @@ func verify_sliceInt(t *testing.T, x *IntVector, elt, i, j int) {
} }
} }
func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) { func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
n := a + b + c n := a + b + c
if x.Len() != n { if x.Len() != n {
...@@ -237,7 +228,6 @@ func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) { ...@@ -237,7 +228,6 @@ func verify_patternInt(t *testing.T, x *IntVector, a, b, c int) {
verify_sliceInt(t, x, 0, a+b, n) verify_sliceInt(t, x, 0, a+b, n)
} }
func make_vectorInt(elt, len int) *IntVector { func make_vectorInt(elt, len int) *IntVector {
x := new(IntVector).Resize(len, 0) x := new(IntVector).Resize(len, 0)
for i := 0; i < len; i++ { for i := 0; i < len; i++ {
...@@ -246,7 +236,6 @@ func make_vectorInt(elt, len int) *IntVector { ...@@ -246,7 +236,6 @@ func make_vectorInt(elt, len int) *IntVector {
return x return x
} }
func TestIntInsertVector(t *testing.T) { func TestIntInsertVector(t *testing.T) {
// 1 // 1
a := make_vectorInt(0, 0) a := make_vectorInt(0, 0)
...@@ -270,7 +259,6 @@ func TestIntInsertVector(t *testing.T) { ...@@ -270,7 +259,6 @@ func TestIntInsertVector(t *testing.T) {
verify_patternInt(t, a, 8, 1000, 2) verify_patternInt(t, a, 8, 1000, 2)
} }
func TestIntDo(t *testing.T) { func TestIntDo(t *testing.T) {
const n = 25 const n = 25
const salt = 17 const salt = 17
...@@ -325,7 +313,6 @@ func TestIntDo(t *testing.T) { ...@@ -325,7 +313,6 @@ func TestIntDo(t *testing.T) {
} }
func TestIntVectorCopy(t *testing.T) { func TestIntVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector // verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10 const Len = 10
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
package vector package vector
import ( import (
"fmt" "fmt"
"sort" "sort"
...@@ -17,28 +16,23 @@ var ( ...@@ -17,28 +16,23 @@ var (
strzero string strzero string
) )
func int2Value(x int) int { return x } func int2Value(x int) int { return x }
func int2IntValue(x int) int { return x } func int2IntValue(x int) int { return x }
func int2StrValue(x int) string { return string(x) } func int2StrValue(x int) string { return string(x) }
func elem2Value(x interface{}) int { return x.(int) } func elem2Value(x interface{}) int { return x.(int) }
func elem2IntValue(x int) int { return x } func elem2IntValue(x int) int { return x }
func elem2StrValue(x string) string { return x } func elem2StrValue(x string) string { return x }
func intf2Value(x interface{}) int { return x.(int) } func intf2Value(x interface{}) int { return x.(int) }
func intf2IntValue(x interface{}) int { return x.(int) } func intf2IntValue(x interface{}) int { return x.(int) }
func intf2StrValue(x interface{}) string { return x.(string) } func intf2StrValue(x interface{}) string { return x.(string) }
type VectorInterface interface { type VectorInterface interface {
Len() int Len() int
Cap() int Cap() int
} }
func checkSize(t *testing.T, v VectorInterface, len, cap int) { func checkSize(t *testing.T, v VectorInterface, len, cap int) {
if v.Len() != len { if v.Len() != len {
t.Errorf("%T expected len = %d; found %d", v, len, v.Len()) t.Errorf("%T expected len = %d; found %d", v, len, v.Len())
...@@ -48,10 +42,8 @@ func checkSize(t *testing.T, v VectorInterface, len, cap int) { ...@@ -48,10 +42,8 @@ func checkSize(t *testing.T, v VectorInterface, len, cap int) {
} }
} }
func val(i int) int { return i*991 - 1234 } func val(i int) int { return i*991 - 1234 }
func TestSorting(t *testing.T) { func TestSorting(t *testing.T) {
const n = 100 const n = 100
...@@ -72,5 +64,4 @@ func TestSorting(t *testing.T) { ...@@ -72,5 +64,4 @@ func TestSorting(t *testing.T) {
} }
} }
func tname(x interface{}) string { return fmt.Sprintf("%T: ", x) } func tname(x interface{}) string { return fmt.Sprintf("%T: ", x) }
...@@ -11,10 +11,8 @@ import ( ...@@ -11,10 +11,8 @@ import (
"testing" "testing"
) )
const memTestN = 1000000 const memTestN = 1000000
func s(n uint64) string { func s(n uint64) string {
str := fmt.Sprintf("%d", n) str := fmt.Sprintf("%d", n)
lens := len(str) lens := len(str)
...@@ -31,7 +29,6 @@ func s(n uint64) string { ...@@ -31,7 +29,6 @@ func s(n uint64) string {
return strings.Join(a, " ") return strings.Join(a, " ")
} }
func TestVectorNums(t *testing.T) { func TestVectorNums(t *testing.T) {
if testing.Short() { if testing.Short() {
return return
...@@ -52,7 +49,6 @@ func TestVectorNums(t *testing.T) { ...@@ -52,7 +49,6 @@ func TestVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN) t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
} }
func TestIntVectorNums(t *testing.T) { func TestIntVectorNums(t *testing.T) {
if testing.Short() { if testing.Short() {
return return
...@@ -73,7 +69,6 @@ func TestIntVectorNums(t *testing.T) { ...@@ -73,7 +69,6 @@ func TestIntVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN) t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
} }
func TestStringVectorNums(t *testing.T) { func TestStringVectorNums(t *testing.T) {
if testing.Short() { if testing.Short() {
return return
...@@ -94,7 +89,6 @@ func TestStringVectorNums(t *testing.T) { ...@@ -94,7 +89,6 @@ func TestStringVectorNums(t *testing.T) {
t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN) t.Logf("%T.Push(%#v), n = %s: Alloc/n = %.2f\n", v, c, s(memTestN), float64(n)/memTestN)
} }
func BenchmarkVectorNums(b *testing.B) { func BenchmarkVectorNums(b *testing.B) {
c := int(0) c := int(0)
var v Vector var v Vector
...@@ -106,7 +100,6 @@ func BenchmarkVectorNums(b *testing.B) { ...@@ -106,7 +100,6 @@ func BenchmarkVectorNums(b *testing.B) {
} }
} }
func BenchmarkIntVectorNums(b *testing.B) { func BenchmarkIntVectorNums(b *testing.B) {
c := int(0) c := int(0)
var v IntVector var v IntVector
...@@ -118,7 +111,6 @@ func BenchmarkIntVectorNums(b *testing.B) { ...@@ -118,7 +111,6 @@ func BenchmarkIntVectorNums(b *testing.B) {
} }
} }
func BenchmarkStringVectorNums(b *testing.B) { func BenchmarkStringVectorNums(b *testing.B) {
c := "" c := ""
var v StringVector var v StringVector
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
package vector package vector
func (p *StringVector) realloc(length, capacity int) (b []string) { func (p *StringVector) realloc(length, capacity int) (b []string) {
if capacity < initialSize { if capacity < initialSize {
capacity = initialSize capacity = initialSize
...@@ -21,7 +20,6 @@ func (p *StringVector) realloc(length, capacity int) (b []string) { ...@@ -21,7 +20,6 @@ func (p *StringVector) realloc(length, capacity int) (b []string) {
return return
} }
// Insert n elements at position i. // Insert n elements at position i.
func (p *StringVector) Expand(i, n int) { func (p *StringVector) Expand(i, n int) {
a := *p a := *p
...@@ -51,11 +49,9 @@ func (p *StringVector) Expand(i, n int) { ...@@ -51,11 +49,9 @@ func (p *StringVector) Expand(i, n int) {
*p = a *p = a
} }
// Insert n elements at the end of a vector. // Insert n elements at the end of a vector.
func (p *StringVector) Extend(n int) { p.Expand(len(*p), n) } func (p *StringVector) Extend(n int) { p.Expand(len(*p), n) }
// Resize changes the length and capacity of a vector. // Resize changes the length and capacity of a vector.
// If the new length is shorter than the current length, Resize discards // If the new length is shorter than the current length, Resize discards
// trailing elements. If the new length is longer than the current length, // trailing elements. If the new length is longer than the current length,
...@@ -80,30 +76,24 @@ func (p *StringVector) Resize(length, capacity int) *StringVector { ...@@ -80,30 +76,24 @@ func (p *StringVector) Resize(length, capacity int) *StringVector {
return p return p
} }
// Len returns the number of elements in the vector. // Len returns the number of elements in the vector.
// Same as len(*p). // Same as len(*p).
func (p *StringVector) Len() int { return len(*p) } func (p *StringVector) Len() int { return len(*p) }
// Cap returns the capacity of the vector; that is, the // Cap returns the capacity of the vector; that is, the
// maximum length the vector can grow without resizing. // maximum length the vector can grow without resizing.
// Same as cap(*p). // Same as cap(*p).
func (p *StringVector) Cap() int { return cap(*p) } func (p *StringVector) Cap() int { return cap(*p) }
// At returns the i'th element of the vector. // At returns the i'th element of the vector.
func (p *StringVector) At(i int) string { return (*p)[i] } func (p *StringVector) At(i int) string { return (*p)[i] }
// Set sets the i'th element of the vector to value x. // Set sets the i'th element of the vector to value x.
func (p *StringVector) Set(i int, x string) { (*p)[i] = x } func (p *StringVector) Set(i int, x string) { (*p)[i] = x }
// Last returns the element in the vector of highest index. // Last returns the element in the vector of highest index.
func (p *StringVector) Last() string { return (*p)[len(*p)-1] } func (p *StringVector) Last() string { return (*p)[len(*p)-1] }
// Copy makes a copy of the vector and returns it. // Copy makes a copy of the vector and returns it.
func (p *StringVector) Copy() StringVector { func (p *StringVector) Copy() StringVector {
arr := make(StringVector, len(*p)) arr := make(StringVector, len(*p))
...@@ -111,7 +101,6 @@ func (p *StringVector) Copy() StringVector { ...@@ -111,7 +101,6 @@ func (p *StringVector) Copy() StringVector {
return arr return arr
} }
// Insert inserts into the vector an element of value x before // Insert inserts into the vector an element of value x before
// the current element at index i. // the current element at index i.
func (p *StringVector) Insert(i int, x string) { func (p *StringVector) Insert(i int, x string) {
...@@ -119,7 +108,6 @@ func (p *StringVector) Insert(i int, x string) { ...@@ -119,7 +108,6 @@ func (p *StringVector) Insert(i int, x string) {
(*p)[i] = x (*p)[i] = x
} }
// Delete deletes the i'th element of the vector. The gap is closed so the old // Delete deletes the i'th element of the vector. The gap is closed so the old
// element at index i+1 has index i afterwards. // element at index i+1 has index i afterwards.
func (p *StringVector) Delete(i int) { func (p *StringVector) Delete(i int) {
...@@ -132,7 +120,6 @@ func (p *StringVector) Delete(i int) { ...@@ -132,7 +120,6 @@ func (p *StringVector) Delete(i int) {
*p = a[0 : n-1] *p = a[0 : n-1]
} }
// InsertVector inserts into the vector the contents of the vector // InsertVector inserts into the vector the contents of the vector
// x such that the 0th element of x appears at index i after insertion. // x such that the 0th element of x appears at index i after insertion.
func (p *StringVector) InsertVector(i int, x *StringVector) { func (p *StringVector) InsertVector(i int, x *StringVector) {
...@@ -142,7 +129,6 @@ func (p *StringVector) InsertVector(i int, x *StringVector) { ...@@ -142,7 +129,6 @@ func (p *StringVector) InsertVector(i int, x *StringVector) {
copy((*p)[i:i+len(b)], b) copy((*p)[i:i+len(b)], b)
} }
// Cut deletes elements i through j-1, inclusive. // Cut deletes elements i through j-1, inclusive.
func (p *StringVector) Cut(i, j int) { func (p *StringVector) Cut(i, j int) {
a := *p a := *p
...@@ -158,7 +144,6 @@ func (p *StringVector) Cut(i, j int) { ...@@ -158,7 +144,6 @@ func (p *StringVector) Cut(i, j int) {
*p = a[0:m] *p = a[0:m]
} }
// Slice returns a new sub-vector by slicing the old one to extract slice [i:j]. // Slice returns a new sub-vector by slicing the old one to extract slice [i:j].
// The elements are copied. The original vector is unchanged. // The elements are copied. The original vector is unchanged.
func (p *StringVector) Slice(i, j int) *StringVector { func (p *StringVector) Slice(i, j int) *StringVector {
...@@ -168,13 +153,11 @@ func (p *StringVector) Slice(i, j int) *StringVector { ...@@ -168,13 +153,11 @@ func (p *StringVector) Slice(i, j int) *StringVector {
return &s return &s
} }
// Convenience wrappers // Convenience wrappers
// Push appends x to the end of the vector. // Push appends x to the end of the vector.
func (p *StringVector) Push(x string) { p.Insert(len(*p), x) } func (p *StringVector) Push(x string) { p.Insert(len(*p), x) }
// Pop deletes the last element of the vector. // Pop deletes the last element of the vector.
func (p *StringVector) Pop() string { func (p *StringVector) Pop() string {
a := *p a := *p
...@@ -187,18 +170,15 @@ func (p *StringVector) Pop() string { ...@@ -187,18 +170,15 @@ func (p *StringVector) Pop() string {
return x return x
} }
// AppendVector appends the entire vector x to the end of this vector. // AppendVector appends the entire vector x to the end of this vector.
func (p *StringVector) AppendVector(x *StringVector) { p.InsertVector(len(*p), x) } func (p *StringVector) AppendVector(x *StringVector) { p.InsertVector(len(*p), x) }
// Swap exchanges the elements at indexes i and j. // Swap exchanges the elements at indexes i and j.
func (p *StringVector) Swap(i, j int) { func (p *StringVector) Swap(i, j int) {
a := *p a := *p
a[i], a[j] = a[j], a[i] a[i], a[j] = a[j], a[i]
} }
// Do calls function f for each element of the vector, in order. // Do calls function f for each element of the vector, in order.
// The behavior of Do is undefined if f changes *p. // The behavior of Do is undefined if f changes *p.
func (p *StringVector) Do(f func(elem string)) { func (p *StringVector) Do(f func(elem string)) {
......
...@@ -9,7 +9,6 @@ package vector ...@@ -9,7 +9,6 @@ package vector
import "testing" import "testing"
func TestStrZeroLen(t *testing.T) { func TestStrZeroLen(t *testing.T) {
a := new(StringVector) a := new(StringVector)
if a.Len() != 0 { if a.Len() != 0 {
...@@ -27,7 +26,6 @@ func TestStrZeroLen(t *testing.T) { ...@@ -27,7 +26,6 @@ func TestStrZeroLen(t *testing.T) {
} }
} }
func TestStrResize(t *testing.T) { func TestStrResize(t *testing.T) {
var a StringVector var a StringVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
...@@ -40,7 +38,6 @@ func TestStrResize(t *testing.T) { ...@@ -40,7 +38,6 @@ func TestStrResize(t *testing.T) {
checkSize(t, a.Resize(11, 100), 11, 100) checkSize(t, a.Resize(11, 100), 11, 100)
} }
func TestStrResize2(t *testing.T) { func TestStrResize2(t *testing.T) {
var a StringVector var a StringVector
checkSize(t, &a, 0, 0) checkSize(t, &a, 0, 0)
...@@ -62,7 +59,6 @@ func TestStrResize2(t *testing.T) { ...@@ -62,7 +59,6 @@ func TestStrResize2(t *testing.T) {
} }
} }
func checkStrZero(t *testing.T, a *StringVector, i int) { func checkStrZero(t *testing.T, a *StringVector, i int) {
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
if a.At(j) == strzero { if a.At(j) == strzero {
...@@ -82,7 +78,6 @@ func checkStrZero(t *testing.T, a *StringVector, i int) { ...@@ -82,7 +78,6 @@ func checkStrZero(t *testing.T, a *StringVector, i int) {
} }
} }
func TestStrTrailingElements(t *testing.T) { func TestStrTrailingElements(t *testing.T) {
var a StringVector var a StringVector
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
...@@ -95,7 +90,6 @@ func TestStrTrailingElements(t *testing.T) { ...@@ -95,7 +90,6 @@ func TestStrTrailingElements(t *testing.T) {
checkStrZero(t, &a, 5) checkStrZero(t, &a, 5)
} }
func TestStrAccess(t *testing.T) { func TestStrAccess(t *testing.T) {
const n = 100 const n = 100
var a StringVector var a StringVector
...@@ -120,7 +114,6 @@ func TestStrAccess(t *testing.T) { ...@@ -120,7 +114,6 @@ func TestStrAccess(t *testing.T) {
} }
} }
func TestStrInsertDeleteClear(t *testing.T) { func TestStrInsertDeleteClear(t *testing.T) {
const n = 100 const n = 100
var a StringVector var a StringVector
...@@ -207,7 +200,6 @@ func TestStrInsertDeleteClear(t *testing.T) { ...@@ -207,7 +200,6 @@ func TestStrInsertDeleteClear(t *testing.T) {
} }
} }
func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) { func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
for k := i; k < j; k++ { for k := i; k < j; k++ {
if elem2StrValue(x.At(k)) != int2StrValue(elt) { if elem2StrValue(x.At(k)) != int2StrValue(elt) {
...@@ -223,7 +215,6 @@ func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) { ...@@ -223,7 +215,6 @@ func verify_sliceStr(t *testing.T, x *StringVector, elt, i, j int) {
} }
} }
func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) { func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
n := a + b + c n := a + b + c
if x.Len() != n { if x.Len() != n {
...@@ -237,7 +228,6 @@ func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) { ...@@ -237,7 +228,6 @@ func verify_patternStr(t *testing.T, x *StringVector, a, b, c int) {
verify_sliceStr(t, x, 0, a+b, n) verify_sliceStr(t, x, 0, a+b, n)
} }
func make_vectorStr(elt, len int) *StringVector { func make_vectorStr(elt, len int) *StringVector {
x := new(StringVector).Resize(len, 0) x := new(StringVector).Resize(len, 0)
for i := 0; i < len; i++ { for i := 0; i < len; i++ {
...@@ -246,7 +236,6 @@ func make_vectorStr(elt, len int) *StringVector { ...@@ -246,7 +236,6 @@ func make_vectorStr(elt, len int) *StringVector {
return x return x
} }
func TestStrInsertVector(t *testing.T) { func TestStrInsertVector(t *testing.T) {
// 1 // 1
a := make_vectorStr(0, 0) a := make_vectorStr(0, 0)
...@@ -270,7 +259,6 @@ func TestStrInsertVector(t *testing.T) { ...@@ -270,7 +259,6 @@ func TestStrInsertVector(t *testing.T) {
verify_patternStr(t, a, 8, 1000, 2) verify_patternStr(t, a, 8, 1000, 2)
} }
func TestStrDo(t *testing.T) { func TestStrDo(t *testing.T) {
const n = 25 const n = 25
const salt = 17 const salt = 17
...@@ -325,7 +313,6 @@ func TestStrDo(t *testing.T) { ...@@ -325,7 +313,6 @@ func TestStrDo(t *testing.T) {
} }
func TestStrVectorCopy(t *testing.T) { func TestStrVectorCopy(t *testing.T) {
// verify Copy() returns a copy, not simply a slice of the original vector // verify Copy() returns a copy, not simply a slice of the original vector
const Len = 10 const Len = 10
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment