Commit be47d6ec by Ian Lance Taylor

libgo: Update to Go 1.1.1.

From-SVN: r200974
parent efb30cde
229081515358
a7bd9a33067b
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
......@@ -9,12 +9,14 @@
// References:
// http://www.freebsd.org/cgi/man.cgi?query=tar&sektion=5
// http://www.gnu.org/software/tar/manual/html_node/Standard.html
// http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html
package tar
import (
"errors"
"fmt"
"os"
"path"
"time"
)
......@@ -33,6 +35,8 @@ const (
TypeCont = '7' // reserved
TypeXHeader = 'x' // extended header
TypeXGlobalHeader = 'g' // global extended header
TypeGNULongName = 'L' // Next file has a long name
TypeGNULongLink = 'K' // Next file symlinks to a file w/ a long name
)
// A Header represents a single header in a tar archive.
......@@ -54,55 +58,172 @@ type Header struct {
ChangeTime time.Time // status change time
}
// File name constants from the tar spec.
const (
fileNameSize = 100 // Maximum number of bytes in a standard tar name.
fileNamePrefixSize = 155 // Maximum number of ustar extension bytes.
)
// FileInfo returns an os.FileInfo for the Header.
func (h *Header) FileInfo() os.FileInfo {
return headerFileInfo{h}
}
// headerFileInfo implements os.FileInfo.
type headerFileInfo struct {
h *Header
}
func (fi headerFileInfo) Size() int64 { return fi.h.Size }
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time { return fi.h.ModTime }
func (fi headerFileInfo) Sys() interface{} { return fi.h }
// Name returns the base name of the file.
func (fi headerFileInfo) Name() string {
if fi.IsDir() {
return path.Clean(fi.h.Name)
}
return fi.h.Name
}
// Mode returns the permission and mode bits for the headerFileInfo.
func (fi headerFileInfo) Mode() (mode os.FileMode) {
// Set file permission bits.
mode = os.FileMode(fi.h.Mode).Perm()
// Set setuid, setgid and sticky bits.
if fi.h.Mode&c_ISUID != 0 {
// setuid
mode |= os.ModeSetuid
}
if fi.h.Mode&c_ISGID != 0 {
// setgid
mode |= os.ModeSetgid
}
if fi.h.Mode&c_ISVTX != 0 {
// sticky
mode |= os.ModeSticky
}
// Set file mode bits.
// clear perm, setuid, setgid and sticky bits.
m := os.FileMode(fi.h.Mode) &^ 07777
if m == c_ISDIR {
// directory
mode |= os.ModeDir
}
if m == c_ISFIFO {
// named pipe (FIFO)
mode |= os.ModeNamedPipe
}
if m == c_ISLNK {
// symbolic link
mode |= os.ModeSymlink
}
if m == c_ISBLK {
// device file
mode |= os.ModeDevice
}
if m == c_ISCHR {
// Unix character device
mode |= os.ModeDevice
mode |= os.ModeCharDevice
}
if m == c_ISSOCK {
// Unix domain socket
mode |= os.ModeSocket
}
switch fi.h.Typeflag {
case TypeLink, TypeSymlink:
// hard link, symbolic link
mode |= os.ModeSymlink
case TypeChar:
// character device node
mode |= os.ModeDevice
mode |= os.ModeCharDevice
case TypeBlock:
// block device node
mode |= os.ModeDevice
case TypeDir:
// directory
mode |= os.ModeDir
case TypeFifo:
// fifo node
mode |= os.ModeNamedPipe
}
return mode
}
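A minimal sketch (not part of the commit) of how the Typeflag and permission bits above surface through Header.FileInfo():

package main

import (
    "archive/tar"
    "fmt"
)

func main() {
    hdr := &tar.Header{Name: "dir/", Mode: 0755, Typeflag: tar.TypeDir}
    fi := hdr.FileInfo()
    fmt.Println(fi.IsDir(), fi.Mode()) // true drwxr-xr-x
}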
// sysStat, if non-nil, populates h from system-dependent fields of fi.
var sysStat func(fi os.FileInfo, h *Header) error
// Mode constants from the tar spec.
const (
c_ISDIR = 040000
c_ISFIFO = 010000
c_ISREG = 0100000
c_ISLNK = 0120000
c_ISBLK = 060000
c_ISCHR = 020000
c_ISSOCK = 0140000
c_ISUID = 04000 // Set uid
c_ISGID = 02000 // Set gid
c_ISVTX = 01000 // Save text (sticky bit)
c_ISDIR = 040000 // Directory
c_ISFIFO = 010000 // FIFO
c_ISREG = 0100000 // Regular file
c_ISLNK = 0120000 // Symbolic link
c_ISBLK = 060000 // Block special file
c_ISCHR = 020000 // Character special file
c_ISSOCK = 0140000 // Socket
)
// FileInfoHeader creates a partially-populated Header from fi.
// If fi describes a symlink, FileInfoHeader records link as the link target.
// If fi describes a directory, a slash is appended to the name.
func FileInfoHeader(fi os.FileInfo, link string) (*Header, error) {
if fi == nil {
return nil, errors.New("tar: FileInfo is nil")
}
fm := fi.Mode()
h := &Header{
Name: fi.Name(),
ModTime: fi.ModTime(),
Mode: int64(fi.Mode().Perm()), // or'd with c_IS* constants later
Mode: int64(fm.Perm()), // or'd with c_IS* constants later
}
switch {
case fi.Mode()&os.ModeType == 0:
case fm.IsRegular():
h.Mode |= c_ISREG
h.Typeflag = TypeReg
h.Size = fi.Size()
case fi.IsDir():
h.Typeflag = TypeDir
h.Mode |= c_ISDIR
case fi.Mode()&os.ModeSymlink != 0:
h.Name += "/"
case fm&os.ModeSymlink != 0:
h.Typeflag = TypeSymlink
h.Mode |= c_ISLNK
h.Linkname = link
case fi.Mode()&os.ModeDevice != 0:
if fi.Mode()&os.ModeCharDevice != 0 {
case fm&os.ModeDevice != 0:
if fm&os.ModeCharDevice != 0 {
h.Mode |= c_ISCHR
h.Typeflag = TypeChar
} else {
h.Mode |= c_ISBLK
h.Typeflag = TypeBlock
}
case fi.Mode()&os.ModeSocket != 0:
case fm&os.ModeNamedPipe != 0:
h.Typeflag = TypeFifo
h.Mode |= c_ISFIFO
case fm&os.ModeSocket != 0:
h.Mode |= c_ISSOCK
default:
return nil, fmt.Errorf("archive/tar: unknown file mode %v", fi.Mode())
return nil, fmt.Errorf("archive/tar: unknown file mode %v", fm)
}
if fm&os.ModeSetuid != 0 {
h.Mode |= c_ISUID
}
if fm&os.ModeSetgid != 0 {
h.Mode |= c_ISGID
}
if fm&os.ModeSticky != 0 {
h.Mode |= c_ISVTX
}
if sysStat != nil {
return h, sysStat(fi, h)
......
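Going the other direction, a minimal sketch (not part of the commit) of FileInfoHeader; the path is hypothetical:

package main

import (
    "archive/tar"
    "fmt"
    "log"
    "os"
)

func main() {
    fi, err := os.Lstat("notes.txt") // hypothetical file
    if err != nil {
        log.Fatal(err)
    }
    // For symlinks, pass the link target instead of "".
    hdr, err := tar.FileInfoHeader(fi, "")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%s mode=%#o size=%d\n", hdr.Name, hdr.Mode, hdr.Size)
}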
......@@ -14,6 +14,7 @@ import (
"io/ioutil"
"os"
"strconv"
"strings"
"time"
)
......@@ -21,24 +22,12 @@ var (
ErrHeader = errors.New("archive/tar: invalid tar header")
)
const maxNanoSecondIntSize = 9
// A Reader provides sequential access to the contents of a tar archive.
// A tar archive consists of a sequence of files.
// The Next method advances to the next file in the archive (including the first),
// and then it can be treated as an io.Reader to access the file's data.
//
// Example:
// tr := tar.NewReader(r)
// for {
// hdr, err := tr.Next()
// if err == io.EOF {
// // end of tar archive
// break
// }
// if err != nil {
// // handle error
// }
// io.Copy(data, tr)
// }
type Reader struct {
r io.Reader
err error
......@@ -55,13 +44,183 @@ func (tr *Reader) Next() (*Header, error) {
if tr.err == nil {
tr.skipUnread()
}
if tr.err == nil {
if tr.err != nil {
return hdr, tr.err
}
hdr = tr.readHeader()
if hdr == nil {
return hdr, tr.err
}
// Check for PAX/GNU header.
switch hdr.Typeflag {
case TypeXHeader:
// PAX extended header
headers, err := parsePAX(tr)
if err != nil {
return nil, err
}
// We actually read the whole file,
// but this skips alignment padding
tr.skipUnread()
hdr = tr.readHeader()
mergePAX(hdr, headers)
return hdr, nil
case TypeGNULongName:
// We have a GNU long name header. Its contents are the real file name.
realname, err := ioutil.ReadAll(tr)
if err != nil {
return nil, err
}
hdr, err := tr.Next()
hdr.Name = cString(realname)
return hdr, err
case TypeGNULongLink:
// We have a GNU long link header.
realname, err := ioutil.ReadAll(tr)
if err != nil {
return nil, err
}
hdr, err := tr.Next()
hdr.Linkname = cString(realname)
return hdr, err
}
return hdr, tr.err
}
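A sketch (not part of the commit) of the reader loop after this change: the 'x', 'L' and 'K' meta entries are consumed inside Next, so callers only ever see the resolved names. The archive path is hypothetical.

package main

import (
    "archive/tar"
    "fmt"
    "io"
    "log"
    "os"
)

func main() {
    f, err := os.Open("long-names.tar") // hypothetical archive
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()
    tr := tar.NewReader(f)
    for {
        hdr, err := tr.Next() // PAX and GNU long-name entries are handled internally
        if err == io.EOF {
            break
        }
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println(hdr.Name) // full name, even when longer than the 100-byte field
    }
}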
// Parse bytes as a NUL-terminated C-style string.
// mergePAX merges well-known headers according to the PAX standard.
// In general, values parsed from the extended header override the
// corresponding fields in the header struct, since they carry higher
// precision or longer values. This is especially useful for the name
// and linkname fields.
func mergePAX(hdr *Header, headers map[string]string) error {
for k, v := range headers {
switch k {
case "path":
hdr.Name = v
case "linkpath":
hdr.Linkname = v
case "gname":
hdr.Gname = v
case "uname":
hdr.Uname = v
case "uid":
uid, err := strconv.ParseInt(v, 10, 0)
if err != nil {
return err
}
hdr.Uid = int(uid)
case "gid":
gid, err := strconv.ParseInt(v, 10, 0)
if err != nil {
return err
}
hdr.Gid = int(gid)
case "atime":
t, err := parsePAXTime(v)
if err != nil {
return err
}
hdr.AccessTime = t
case "mtime":
t, err := parsePAXTime(v)
if err != nil {
return err
}
hdr.ModTime = t
case "ctime":
t, err := parsePAXTime(v)
if err != nil {
return err
}
hdr.ChangeTime = t
case "size":
size, err := strconv.ParseInt(v, 10, 0)
if err != nil {
return err
}
hdr.Size = int64(size)
}
}
return nil
}
// parsePAXTime takes a string of the form %d.%d as described in
// the PAX specification.
func parsePAXTime(t string) (time.Time, error) {
buf := []byte(t)
pos := bytes.IndexByte(buf, '.')
var seconds, nanoseconds int64
var err error
if pos == -1 {
seconds, err = strconv.ParseInt(t, 10, 0)
if err != nil {
return time.Time{}, err
}
} else {
seconds, err = strconv.ParseInt(string(buf[:pos]), 10, 0)
if err != nil {
return time.Time{}, err
}
nano_buf := string(buf[pos+1:])
// Pad as needed before converting to a decimal.
// For example .030 -> .030000000 -> 30000000 nanoseconds
if len(nano_buf) < maxNanoSecondIntSize {
// Right pad
nano_buf += strings.Repeat("0", maxNanoSecondIntSize-len(nano_buf))
} else if len(nano_buf) > maxNanoSecondIntSize {
// Right truncate
nano_buf = nano_buf[:maxNanoSecondIntSize]
}
nanoseconds, err = strconv.ParseInt(string(nano_buf), 10, 0)
if err != nil {
return time.Time{}, err
}
}
ts := time.Unix(seconds, nanoseconds)
return ts, nil
}
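For example, with the values used in the tests below: "1350244992.023960108" yields time.Unix(1350244992, 23960108); the shorter fraction in "1350244992.02396010" is right-padded to nine digits and yields 23960100 nanoseconds; the longer fraction in "1350244992.0239601089" is truncated to nine digits and yields 23960108.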
// parsePAX parses PAX headers.
// If an extended header (type 'x') is invalid, ErrHeader is returned.
func parsePAX(r io.Reader) (map[string]string, error) {
buf, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
headers := make(map[string]string)
// Each record is constructed as
// "%d %s=%s\n", length, keyword, value
for len(buf) > 0 {
// The loop body never runs if the header was empty to start with.
// The size field ends at the first space.
sp := bytes.IndexByte(buf, ' ')
if sp == -1 {
return nil, ErrHeader
}
// Parse the first token as a decimal integer.
n, err := strconv.ParseInt(string(buf[:sp]), 10, 0)
if err != nil {
return nil, ErrHeader
}
// Extract the record: sp+1 at the beginning to eat the ' ' after the
// length, n-1 at the end to skip the trailing newline.
var record []byte
record, buf = buf[sp+1:n-1], buf[n:]
// The first equals is guaranteed to mark the end of the key.
// Everything else is value.
eq := bytes.IndexByte(record, '=')
if eq == -1 {
return nil, ErrHeader
}
key, value := record[:eq], record[eq+1:]
headers[string(key)] = string(value)
}
return headers, nil
}
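For example, the record "30 mtime=1350244992.023960108\n" (30 bytes in total, counting the length digits, the space and the trailing newline) produces headers["mtime"] = "1350244992.023960108".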
// cString parses bytes as a NUL-terminated C-style string.
// If a NUL byte is not found then the whole slice is returned as a string.
func cString(b []byte) string {
n := 0
......@@ -99,7 +258,7 @@ func (tr *Reader) octal(b []byte) int64 {
return int64(x)
}
// Skip any unread bytes in the existing file entry, as well as any alignment padding.
// skipUnread skips any unread bytes in the existing file entry, as well as any alignment padding.
func (tr *Reader) skipUnread() {
nr := tr.nb + tr.pad // number of bytes to skip
tr.nb, tr.pad = 0, 0
......
......@@ -10,6 +10,8 @@ import (
"fmt"
"io"
"os"
"reflect"
"strings"
"testing"
"time"
)
......@@ -108,6 +110,38 @@ var untarTests = []*untarTest{
},
},
},
{
file: "testdata/pax.tar",
headers: []*Header{
{
Name: "a/123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100",
Mode: 0664,
Uid: 1000,
Gid: 1000,
Uname: "shane",
Gname: "shane",
Size: 7,
ModTime: time.Unix(1350244992, 23960108),
ChangeTime: time.Unix(1350244992, 23960108),
AccessTime: time.Unix(1350244992, 23960108),
Typeflag: TypeReg,
},
{
Name: "a/b",
Mode: 0777,
Uid: 1000,
Gid: 1000,
Uname: "shane",
Gname: "shane",
Size: 0,
ModTime: time.Unix(1350266320, 910238425),
ChangeTime: time.Unix(1350266320, 910238425),
AccessTime: time.Unix(1350266320, 910238425),
Typeflag: TypeSymlink,
Linkname: "123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100",
},
},
},
}
func TestReader(t *testing.T) {
......@@ -133,7 +167,7 @@ testLoop:
}
hdr, err := tr.Next()
if err == io.EOF {
break
continue testLoop
}
if hdr != nil || err != nil {
t.Errorf("test %d: Unexpected entry or error: hdr=%v err=%v", i, hdr, err)
......@@ -260,3 +294,73 @@ func TestNonSeekable(t *testing.T) {
t.Errorf("Didn't process all files\nexpected: %d\nprocessed %d\n", len(test.headers), nread)
}
}
func TestParsePAXHeader(t *testing.T) {
paxTests := [][3]string{
{"a", "a=name", "10 a=name\n"}, // Test case involving multiple acceptable lengths
{"a", "a=name", "9 a=name\n"}, // Test case involving multiple acceptable length
{"mtime", "mtime=1350244992.023960108", "30 mtime=1350244992.023960108\n"}}
for _, test := range paxTests {
key, expected, raw := test[0], test[1], test[2]
reader := bytes.NewBuffer([]byte(raw))
headers, err := parsePAX(reader)
if err != nil {
t.Errorf("Couldn't parse correctly formatted headers: %v", err)
continue
}
if !strings.EqualFold(headers[key], expected) {
t.Errorf("%s header incorrectly parsed: got %s, wanted %s", key, headers[key], expected)
continue
}
trailer := make([]byte, 100)
n, err := reader.Read(trailer)
if err != io.EOF || n != 0 {
t.Error("Buffer wasn't consumed")
}
}
badHeader := bytes.NewBuffer([]byte("3 somelongkey="))
if _, err := parsePAX(badHeader); err != ErrHeader {
t.Fatal("Unexpected success when parsing bad header")
}
}
func TestParsePAXTime(t *testing.T) {
// Some valid PAX time values
timestamps := map[string]time.Time{
"1350244992.023960108": time.Unix(1350244992, 23960108), // The commoon case
"1350244992.02396010": time.Unix(1350244992, 23960100), // Lower precision value
"1350244992.0239601089": time.Unix(1350244992, 23960108), // Higher precision value
"1350244992": time.Unix(1350244992, 0), // Low precision value
}
for input, expected := range timestamps {
ts, err := parsePAXTime(input)
if err != nil {
t.Fatal(err)
}
if !ts.Equal(expected) {
t.Fatalf("Time parsing failure %s %s", ts, expected)
}
}
}
func TestMergePAX(t *testing.T) {
hdr := new(Header)
// Test a string, integer, and time based value.
headers := map[string]string{
"path": "a/b/c",
"uid": "1000",
"mtime": "1350244992.023960108",
}
err := mergePAX(hdr, headers)
if err != nil {
t.Fatal(err)
}
want := &Header{
Name: "a/b/c",
Uid: 1000,
ModTime: time.Unix(1350244992, 23960108),
}
if !reflect.DeepEqual(hdr, want) {
t.Errorf("incorrect merge: got %+v, want %+v", hdr, want)
}
}
......@@ -14,13 +14,13 @@ import (
)
func TestFileInfoHeader(t *testing.T) {
fi, err := os.Lstat("testdata/small.txt")
fi, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
h, err := FileInfoHeader(fi, "")
if err != nil {
t.Fatalf("on small.txt: %v", err)
t.Fatalf("FileInfoHeader: %v", err)
}
if g, e := h.Name, "small.txt"; g != e {
t.Errorf("Name = %q; want %q", g, e)
......@@ -36,6 +36,30 @@ func TestFileInfoHeader(t *testing.T) {
}
}
func TestFileInfoHeaderDir(t *testing.T) {
fi, err := os.Stat("testdata")
if err != nil {
t.Fatal(err)
}
h, err := FileInfoHeader(fi, "")
if err != nil {
t.Fatalf("FileInfoHeader: %v", err)
}
if g, e := h.Name, "testdata/"; g != e {
t.Errorf("Name = %q; want %q", g, e)
}
// Ignoring c_ISGID for golang.org/issue/4867
if g, e := h.Mode&^c_ISGID, int64(fi.Mode().Perm())|c_ISDIR; g != e {
t.Errorf("Mode = %#o; want %#o", g, e)
}
if g, e := h.Size, int64(0); g != e {
t.Errorf("Size = %v; want %v", g, e)
}
if g, e := h.ModTime, fi.ModTime(); !g.Equal(e) {
t.Errorf("ModTime = %v; want %v", g, e)
}
}
func TestFileInfoHeaderSymlink(t *testing.T) {
h, err := FileInfoHeader(symlink{}, "some-target")
if err != nil {
......@@ -98,3 +122,150 @@ func TestRoundTrip(t *testing.T) {
t.Errorf("Data mismatch.\n got %q\nwant %q", rData, data)
}
}
type headerRoundTripTest struct {
h *Header
fm os.FileMode
}
func TestHeaderRoundTrip(t *testing.T) {
golden := []headerRoundTripTest{
// regular file.
{
h: &Header{
Name: "test.txt",
Mode: 0644 | c_ISREG,
Size: 12,
ModTime: time.Unix(1360600916, 0),
Typeflag: TypeReg,
},
fm: 0644,
},
// hard link.
{
h: &Header{
Name: "hard.txt",
Mode: 0644 | c_ISLNK,
Size: 0,
ModTime: time.Unix(1360600916, 0),
Typeflag: TypeLink,
},
fm: 0644 | os.ModeSymlink,
},
// symbolic link.
{
h: &Header{
Name: "link.txt",
Mode: 0777 | c_ISLNK,
Size: 0,
ModTime: time.Unix(1360600852, 0),
Typeflag: TypeSymlink,
},
fm: 0777 | os.ModeSymlink,
},
// character device node.
{
h: &Header{
Name: "dev/null",
Mode: 0666 | c_ISCHR,
Size: 0,
ModTime: time.Unix(1360578951, 0),
Typeflag: TypeChar,
},
fm: 0666 | os.ModeDevice | os.ModeCharDevice,
},
// block device node.
{
h: &Header{
Name: "dev/sda",
Mode: 0660 | c_ISBLK,
Size: 0,
ModTime: time.Unix(1360578954, 0),
Typeflag: TypeBlock,
},
fm: 0660 | os.ModeDevice,
},
// directory.
{
h: &Header{
Name: "dir/",
Mode: 0755 | c_ISDIR,
Size: 0,
ModTime: time.Unix(1360601116, 0),
Typeflag: TypeDir,
},
fm: 0755 | os.ModeDir,
},
// fifo node.
{
h: &Header{
Name: "dev/initctl",
Mode: 0600 | c_ISFIFO,
Size: 0,
ModTime: time.Unix(1360578949, 0),
Typeflag: TypeFifo,
},
fm: 0600 | os.ModeNamedPipe,
},
// setuid.
{
h: &Header{
Name: "bin/su",
Mode: 0755 | c_ISREG | c_ISUID,
Size: 23232,
ModTime: time.Unix(1355405093, 0),
Typeflag: TypeReg,
},
fm: 0755 | os.ModeSetuid,
},
// setgid.
{
h: &Header{
Name: "group.txt",
Mode: 0750 | c_ISREG | c_ISGID,
Size: 0,
ModTime: time.Unix(1360602346, 0),
Typeflag: TypeReg,
},
fm: 0750 | os.ModeSetgid,
},
// sticky.
{
h: &Header{
Name: "sticky.txt",
Mode: 0600 | c_ISREG | c_ISVTX,
Size: 7,
ModTime: time.Unix(1360602540, 0),
Typeflag: TypeReg,
},
fm: 0600 | os.ModeSticky,
},
}
for i, g := range golden {
fi := g.h.FileInfo()
h2, err := FileInfoHeader(fi, "")
if err != nil {
t.Error(err)
continue
}
if got, want := h2.Name, g.h.Name; got != want {
t.Errorf("i=%d: Name: got %v, want %v", i, got, want)
}
if got, want := h2.Size, g.h.Size; got != want {
t.Errorf("i=%d: Size: got %v, want %v", i, got, want)
}
if got, want := h2.Mode, g.h.Mode; got != want {
t.Errorf("i=%d: Mode: got %o, want %o", i, got, want)
}
if got, want := fi.Mode(), g.fm; got != want {
t.Errorf("i=%d: fi.Mode: got %o, want %o", i, got, want)
}
if got, want := h2.ModTime, g.h.ModTime; got != want {
t.Errorf("i=%d: ModTime: got %v, want %v", i, got, want)
}
if sysh, ok := fi.Sys().(*Header); !ok || sysh != g.h {
t.Errorf("i=%d: Sys didn't return original *Header", i)
}
}
}
......@@ -8,10 +8,14 @@ package tar
// - catch more errors (no first header, etc.)
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path"
"strconv"
"strings"
"time"
)
......@@ -19,23 +23,13 @@ var (
ErrWriteTooLong = errors.New("archive/tar: write too long")
ErrFieldTooLong = errors.New("archive/tar: header field too long")
ErrWriteAfterClose = errors.New("archive/tar: write after close")
errNameTooLong = errors.New("archive/tar: name too long")
)
// A Writer provides sequential writing of a tar archive in POSIX.1 format.
// A tar archive consists of a sequence of files.
// Call WriteHeader to begin a new file, and then call Write to supply that file's data,
// writing at most hdr.Size bytes in total.
//
// Example:
// tw := tar.NewWriter(w)
// hdr := new(tar.Header)
// hdr.Size = length of data in bytes
// // populate other hdr fields as desired
// if err := tw.WriteHeader(hdr); err != nil {
// // handle error
// }
// io.Copy(tw, data)
// tw.Close()
type Writer struct {
w io.Writer
err error
......@@ -130,15 +124,31 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
if tw.err != nil {
return tw.err
}
// Decide whether or not to use PAX extensions
// TODO(shanemhansen): we might want to use PAX headers for
// subsecond time resolution, but for now let's just capture
// the long name/long symlink use case.
suffix := hdr.Name
prefix := ""
if len(hdr.Name) > fileNameSize || len(hdr.Linkname) > fileNameSize {
var err error
prefix, suffix, err = tw.splitUSTARLongName(hdr.Name)
// Either we were unable to pack the long name into ustar format
// or the link name is too long; use PAX headers.
if err == errNameTooLong || len(hdr.Linkname) > fileNameSize {
if err := tw.writePAXHeader(hdr); err != nil {
return err
}
} else if err != nil {
return err
}
}
tw.nb = int64(hdr.Size)
tw.pad = -tw.nb & (blockSize - 1) // blockSize is a power of two
header := make([]byte, blockSize)
s := slicer(header)
// TODO(dsymonds): handle names longer than 100 chars
copy(s.next(100), []byte(hdr.Name))
tw.cString(s.next(fileNameSize), suffix)
// Handle out of range ModTime carefully.
var modTime int64
......@@ -159,11 +169,15 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
tw.cString(s.next(32), hdr.Gname) // 297:329
tw.numeric(s.next(8), hdr.Devmajor) // 329:337
tw.numeric(s.next(8), hdr.Devminor) // 337:345
tw.cString(s.next(155), prefix) // 345:500
// Use the GNU magic instead of POSIX magic if we used any GNU extensions.
if tw.usedBinary {
copy(header[257:265], []byte("ustar \x00"))
}
// Use the ustar magic if we used ustar long names.
if len(prefix) > 0 {
copy(header[257:265], []byte("ustar\000"))
}
// The chksum field is terminated by a NUL and a space.
// This is different from the other octal fields.
......@@ -181,6 +195,78 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
return tw.err
}
// splitUSTARLongName splits a USTAR long name hdr.Name.
// name must be < 256 characters. errNameTooLong is returned
// if hdr.Name can't be split. The splitting heuristic
// is compatible with GNU tar.
func (tw *Writer) splitUSTARLongName(name string) (prefix, suffix string, err error) {
length := len(name)
if length > fileNamePrefixSize+1 {
length = fileNamePrefixSize + 1
} else if name[length-1] == '/' {
length--
}
i := strings.LastIndex(name[:length], "/")
nlen := length - i - 1
if i <= 0 || nlen > fileNameSize || nlen == 0 {
err = errNameTooLong
return
}
prefix, suffix = name[:i], name[i+1:]
return
}
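A hypothetical in-package test (not part of the commit) sketching the split for the name used in the ustar writer test below:

package tar

import (
    "strings"
    "testing"
)

func TestSplitUSTARLongNameSketch(t *testing.T) {
    // 15*9 + 8 = 143 bytes: longer than fileNameSize, but splittable.
    name := strings.Repeat("longname/", 15) + "file.txt"
    var tw Writer
    prefix, suffix, err := tw.splitUSTARLongName(name)
    if err != nil {
        t.Fatal(err)
    }
    if suffix != "file.txt" || len(prefix) != 134 {
        t.Fatalf("got prefix of %d bytes, suffix %q", len(prefix), suffix)
    }
}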
// writePAXHeader writes an extended PAX header to the
// archive.
func (tw *Writer) writePAXHeader(hdr *Header) error {
// Prepare extended header
ext := new(Header)
ext.Typeflag = TypeXHeader
// Setting ModTime is required for reader parsing to
// succeed, and seems harmless enough.
ext.ModTime = hdr.ModTime
// The spec asks that we namespace our pseudo files
// with the current pid.
pid := os.Getpid()
dir, file := path.Split(hdr.Name)
ext.Name = path.Join(dir,
fmt.Sprintf("PaxHeaders.%d", pid), file)[0:100]
// Construct the body
var buf bytes.Buffer
if len(hdr.Name) > fileNameSize {
fmt.Fprint(&buf, paxHeader("path="+hdr.Name))
}
if len(hdr.Linkname) > fileNameSize {
fmt.Fprint(&buf, paxHeader("linkpath="+hdr.Linkname))
}
ext.Size = int64(len(buf.Bytes()))
if err := tw.WriteHeader(ext); err != nil {
return err
}
if _, err := tw.Write(buf.Bytes()); err != nil {
return err
}
if err := tw.Flush(); err != nil {
return err
}
return nil
}
// paxHeader formats a single PAX record, prefixing it with the appropriate length.
func paxHeader(msg string) string {
const padding = 2 // Extra padding for space and newline
size := len(msg) + padding
size += len(strconv.Itoa(size))
record := fmt.Sprintf("%d %s\n", size, msg)
if len(record) != size {
// Final adjustment if adding size increased
// the number of digits in size
size = len(record)
record = fmt.Sprintf("%d %s\n", size, msg)
}
return record
}
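Worked examples, matching the writer tests below: for "a=b" the message is 3 bytes, plus 2 for the space and newline gives 5, plus one digit for the length itself gives 6, so the record is "6 a=b\n". For "a=names" the first pass computes 7+2+1 = 10, but "10 a=names\n" is 11 bytes long, so the length is bumped and the record becomes "11 a=names\n".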
// Write writes to the current entry in the tar archive.
// Write returns the error ErrWriteTooLong if more than
// hdr.Size bytes are written after WriteHeader.
......
......@@ -9,6 +9,7 @@ import (
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"testing"
"testing/iotest"
......@@ -101,6 +102,27 @@ var writerTests = []*writerTest{
},
},
},
// This file was produced using gnu tar 1.17
// gnutar -b 4 --format=ustar (longname/)*15 + file.txt
{
file: "testdata/ustar.tar",
entries: []*writerTestEntry{
{
header: &Header{
Name: strings.Repeat("longname/", 15) + "file.txt",
Mode: 0644,
Uid: 0765,
Gid: 024,
Size: 06,
ModTime: time.Unix(1360135598, 0),
Typeflag: '0',
Uname: "shane",
Gname: "staff",
},
contents: "hello\n",
},
},
},
}
// Render byte array in a two-character hexadecimal string, spaced for easy visual inspection.
......@@ -180,3 +202,61 @@ testLoop:
}
}
}
func TestPax(t *testing.T) {
// Create an archive with a large name
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
if err != nil {
t.Fatalf("os.Stat: %v", err)
}
// Force a PAX long name to be written
longName := strings.Repeat("ab", 100)
contents := strings.Repeat(" ", int(hdr.Size))
hdr.Name = longName
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if _, err = writer.Write([]byte(contents)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Simple test to make sure PAX extensions are in effect
if !bytes.Contains(buf.Bytes(), []byte("PaxHeaders.")) {
t.Fatal("Expected at least one PAX header to be written.")
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Name != longName {
t.Fatal("Couldn't recover long file name")
}
}
func TestPAXHeader(t *testing.T) {
medName := strings.Repeat("CD", 50)
longName := strings.Repeat("AB", 100)
paxTests := [][2]string{
{"name=/etc/hosts", "19 name=/etc/hosts\n"},
{"a=b", "6 a=b\n"}, // Single digit length
{"a=names", "11 a=names\n"}, // Test case involving carries
{"name=" + longName, fmt.Sprintf("210 name=%s\n", longName)},
{"name=" + medName, fmt.Sprintf("110 name=%s\n", medName)}}
for _, test := range paxTests {
key, expected := test[0], test[1]
if result := paxHeader(key); result != expected {
t.Fatalf("paxHeader: got %s, expected %s", result, expected)
}
}
}
......@@ -353,6 +353,11 @@ func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, err error)
if err != nil {
return nil, err
}
// Make sure directoryOffset points to somewhere in our file.
if o := int64(d.directoryOffset); o < 0 || o >= size {
return nil, ErrFormat
}
return d, nil
}
......@@ -407,7 +412,7 @@ func findSignatureInBlock(b []byte) int {
if b[i] == 'P' && b[i+1] == 'K' && b[i+2] == 0x05 && b[i+3] == 0x06 {
// n is length of comment
n := int(b[i+directoryEndLen-2]) | int(b[i+directoryEndLen-1])<<8
if n+directoryEndLen+i == len(b) {
if n+directoryEndLen+i <= len(b) {
return i
}
}
......
......@@ -64,6 +64,24 @@ var tests = []ZipTest{
},
},
{
Name: "test-trailing-junk.zip",
Comment: "This is a zipfile comment.",
File: []ZipTestFile{
{
Name: "test.txt",
Content: []byte("This is a test text file.\n"),
Mtime: "09-05-10 12:12:02",
Mode: 0644,
},
{
Name: "gophercolor16x16.png",
File: "gophercolor16x16.png",
Mtime: "09-05-10 15:52:58",
Mode: 0644,
},
},
},
{
Name: "r.zip",
Source: returnRecursiveZip,
File: []ZipTestFile{
......@@ -262,7 +280,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
}
}
if err != zt.Error {
t.Errorf("error=%v, want %v", err, zt.Error)
t.Errorf("%s: error=%v, want %v", zt.Name, err, zt.Error)
return
}
......
......@@ -64,8 +64,15 @@ const (
zip64ExtraId = 0x0001 // zip64 Extended Information Extra Field
)
// FileHeader describes a file within a zip file.
// See the zip spec for details.
type FileHeader struct {
Name string
// Name is the name of the file.
// It must be a relative path: it must not start with a drive
// letter (e.g. C:) or leading slash, and only forward slashes
// are allowed.
Name string
CreatorVersion uint16
ReaderVersion uint16
Flags uint16
......
......@@ -163,6 +163,9 @@ func (w *Writer) Close() error {
// Create adds a file to the zip file using the provided name.
// It returns a Writer to which the file contents should be written.
// The name must be a relative path: it must not start with a drive
// letter (e.g. C:) or leading slash, and only forward slashes are
// allowed.
// The file's contents must be written to the io.Writer before the next
// call to Create, CreateHeader, or Close.
func (w *Writer) Create(name string) (io.Writer, error) {
......
......@@ -274,11 +274,10 @@ func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
return b.buf, ErrBufferFull
}
}
panic("not reached")
}
// ReadLine is a low-level line-reading primitive. Most callers should use
// ReadBytes('\n') or ReadString('\n') instead.
// ReadBytes('\n') or ReadString('\n') instead or use a Scanner.
//
// ReadLine tries to return a single line, not including the end-of-line bytes.
// If the line was too long for the buffer then isPrefix is set and the
......@@ -331,6 +330,7 @@ func (b *Reader) ReadLine() (line []byte, isPrefix bool, err error) {
// it returns the data read before the error and the error itself (often io.EOF).
// ReadBytes returns err != nil if and only if the returned data does not end in
// delim.
// For simple uses, a Scanner may be more convenient.
func (b *Reader) ReadBytes(delim byte) (line []byte, err error) {
// Use ReadSlice to look for array,
// accumulating full buffers.
......@@ -378,6 +378,7 @@ func (b *Reader) ReadBytes(delim byte) (line []byte, err error) {
// it returns the data read before the error and the error itself (often io.EOF).
// ReadString returns err != nil if and only if the returned data does not end in
// delim.
// For simple uses, a Scanner may be more convenient.
func (b *Reader) ReadString(delim byte) (line string, err error) {
bytes, err := b.ReadBytes(delim)
return string(bytes), err
......
......@@ -7,6 +7,7 @@ package bufio_test
import (
. "bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
......@@ -434,9 +435,12 @@ func TestWriteErrors(t *testing.T) {
t.Errorf("Write hello to %v: %v", w, e)
continue
}
e = buf.Flush()
if e != w.expect {
t.Errorf("Flush %v: got %v, wanted %v", w, e, w.expect)
// Two flushes, to verify the error is sticky.
for i := 0; i < 2; i++ {
e = buf.Flush()
if e != w.expect {
t.Errorf("Flush %d/2 %v: got %v, wanted %v", i+1, w, e, w.expect)
}
}
}
}
......@@ -953,7 +957,7 @@ func TestNegativeRead(t *testing.T) {
t.Fatal("read did not panic")
case error:
if !strings.Contains(err.Error(), "reader returned negative count from Read") {
t.Fatal("wrong panic: %v", err)
t.Fatalf("wrong panic: %v", err)
}
default:
t.Fatalf("unexpected panic value: %T(%v)", err, err)
......@@ -962,6 +966,43 @@ func TestNegativeRead(t *testing.T) {
b.Read(make([]byte, 100))
}
var errFake = errors.New("fake error")
type errorThenGoodReader struct {
didErr bool
nread int
}
func (r *errorThenGoodReader) Read(p []byte) (int, error) {
r.nread++
if !r.didErr {
r.didErr = true
return 0, errFake
}
return len(p), nil
}
func TestReaderClearError(t *testing.T) {
r := &errorThenGoodReader{}
b := NewReader(r)
buf := make([]byte, 1)
if _, err := b.Read(nil); err != nil {
t.Fatalf("1st nil Read = %v; want nil", err)
}
if _, err := b.Read(buf); err != errFake {
t.Fatalf("1st Read = %v; want errFake", err)
}
if _, err := b.Read(nil); err != nil {
t.Fatalf("2nd nil Read = %v; want nil", err)
}
if _, err := b.Read(buf); err != nil {
t.Fatalf("3rd Read with buffer = %v; want nil", err)
}
if r.nread != 2 {
t.Errorf("num reads = %d; want 2", r.nread)
}
}
// An onlyReader only implements io.Reader, no matter what other methods the underlying implementation may have.
type onlyReader struct {
r io.Reader
......
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bufio_test
import (
"bufio"
"fmt"
"os"
"strconv"
"strings"
)
// The simplest use of a Scanner, to read standard input as a set of lines.
func ExampleScanner_lines() {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
fmt.Println(scanner.Text()) // Println will add back the final '\n'
}
if err := scanner.Err(); err != nil {
fmt.Fprintln(os.Stderr, "reading standard input:", err)
}
}
// Use a Scanner to implement a simple word-count utility by scanning the
// input as a sequence of space-delimited tokens.
func ExampleScanner_words() {
// An artificial input source.
const input = "Now is the winter of our discontent,\nMade glorious summer by this sun of York.\n"
scanner := bufio.NewScanner(strings.NewReader(input))
// Set the split function for the scanning operation.
scanner.Split(bufio.ScanWords)
// Count the words.
count := 0
for scanner.Scan() {
count++
}
if err := scanner.Err(); err != nil {
fmt.Fprintln(os.Stderr, "reading input:", err)
}
fmt.Printf("%d\n", count)
// Output: 15
}
// Use a Scanner with a custom split function (built by wrapping ScanWords) to validate
// 32-bit decimal input.
func ExampleScanner_custom() {
// An artificial input source.
const input = "1234 5678 1234567901234567890"
scanner := bufio.NewScanner(strings.NewReader(input))
// Create a custom split function by wrapping the existing ScanWords function.
split := func(data []byte, atEOF bool) (advance int, token []byte, err error) {
advance, token, err = bufio.ScanWords(data, atEOF)
if err == nil && token != nil {
_, err = strconv.ParseInt(string(token), 10, 32)
}
return
}
// Set the split function for the scanning operation.
scanner.Split(split)
// Validate the input
for scanner.Scan() {
fmt.Printf("%s\n", scanner.Text())
}
if err := scanner.Err(); err != nil {
fmt.Printf("Invalid input: %s", err)
}
// Output:
// 1234
// 5678
// Invalid input: strconv.ParseInt: parsing "1234567901234567890": value out of range
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bufio
// Exported for testing only.
import (
"unicode/utf8"
)
var IsSpace = isSpace
func (s *Scanner) MaxTokenSize(n int) {
if n < utf8.UTFMax || n > 1e9 {
panic("bad max token size")
}
if n < len(s.buf) {
s.buf = make([]byte, n)
}
s.maxTokenSize = n
}
// ErrOrEOF is like Err, but returns EOF. Used to test a corner case.
func (s *Scanner) ErrOrEOF() error {
return s.err
}
......@@ -13,6 +13,12 @@ package builtin
// bool is the set of boolean values, true and false.
type bool bool
// true and false are the two untyped boolean values.
const (
true = 0 == 0 // Untyped bool.
false = 0 != 0 // Untyped bool.
)
// uint8 is the set of all unsigned 8-bit integers.
// Range: 0 through 255.
type uint8 uint8
......@@ -85,6 +91,15 @@ type byte byte
// used, by convention, to distinguish character values from integer values.
type rune rune
// iota is a predeclared identifier representing the untyped integer ordinal
// number of the current const specification in a (usually parenthesized)
// const declaration. It is zero-indexed.
const iota = 0 // Untyped int.
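A small illustration (not part of this file) of iota in a parenthesized const declaration:

const (
    KB = 1 << (10 * (iota + 1)) // iota is 0 in the first spec: 1<<10
    MB                          // iota is 1: 1<<20
    GB                          // iota is 2: 1<<30
)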
// nil is a predeclared identifier representing the zero value for a
// pointer, channel, func, interface, map, or slice type.
var nil Type // Type must be a pointer, channel, func, interface, map, or slice type
// Type is here for the purposes of documentation only. It is a stand-in
// for any Go type, but represents the same type for any given function
// invocation.
......@@ -114,6 +129,8 @@ type ComplexType complex64
// result of append, often in the variable holding the slice itself:
// slice = append(slice, elem1, elem2)
// slice = append(slice, anotherSlice...)
// As a special case, it is legal to append a string to a byte slice, like this:
// slice = append([]byte("hello "), "world"...)
func append(slice []Type, elems ...Type) []Type
// The copy built-in function copies elements from a source slice into a
......
......@@ -87,6 +87,13 @@ func (b *Buffer) grow(n int) int {
var buf []byte
if b.buf == nil && n <= len(b.bootstrap) {
buf = b.bootstrap[0:]
} else if m+n <= cap(b.buf)/2 {
// We can slide things down instead of allocating a new
// slice. We only need m+n <= cap(b.buf) to slide, but
// we instead let capacity get twice as large so we
// don't spend all our time copying.
copy(b.buf[:], b.buf[b.off:])
buf = b.buf[:m]
} else {
// not enough space anywhere
buf = makeSlice(2*cap(b.buf) + n)
......@@ -112,20 +119,18 @@ func (b *Buffer) Grow(n int) {
b.buf = b.buf[0:m]
}
// Write appends the contents of p to the buffer. The return
// value n is the length of p; err is always nil.
// If the buffer becomes too large, Write will panic with
// ErrTooLarge.
// Write appends the contents of p to the buffer, growing the buffer as
// needed. The return value n is the length of p; err is always nil. If the
// buffer becomes too large, Write will panic with ErrTooLarge.
func (b *Buffer) Write(p []byte) (n int, err error) {
b.lastRead = opInvalid
m := b.grow(len(p))
return copy(b.buf[m:], p), nil
}
// WriteString appends the contents of s to the buffer. The return
// value n is the length of s; err is always nil.
// If the buffer becomes too large, WriteString will panic with
// ErrTooLarge.
// WriteString appends the contents of s to the buffer, growing the buffer as
// needed. The return value n is the length of s; err is always nil. If the
// buffer becomes too large, WriteString will panic with ErrTooLarge.
func (b *Buffer) WriteString(s string) (n int, err error) {
b.lastRead = opInvalid
m := b.grow(len(s))
......@@ -138,12 +143,10 @@ func (b *Buffer) WriteString(s string) (n int, err error) {
// underlying buffer.
const MinRead = 512
// ReadFrom reads data from r until EOF and appends it to the buffer.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read
// is also returned.
// If the buffer becomes too large, ReadFrom will panic with
// ErrTooLarge.
// ReadFrom reads data from r until EOF and appends it to the buffer, growing
// the buffer as needed. The return value n is the number of bytes read. Any
// error except io.EOF encountered during the read is also returned. If the
// buffer becomes too large, ReadFrom will panic with ErrTooLarge.
func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) {
b.lastRead = opInvalid
// If buffer is empty, reset to recover space.
......@@ -188,10 +191,10 @@ func makeSlice(n int) []byte {
return make([]byte, n)
}
// WriteTo writes data to w until the buffer is drained or an error
// occurs. The return value n is the number of bytes written; it always
// fits into an int, but it is int64 to match the io.WriterTo interface.
// Any error encountered during the write is also returned.
// WriteTo writes data to w until the buffer is drained or an error occurs.
// The return value n is the number of bytes written; it always fits into an
// int, but it is int64 to match the io.WriterTo interface. Any error
// encountered during the write is also returned.
func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) {
b.lastRead = opInvalid
if b.off < len(b.buf) {
......@@ -216,10 +219,9 @@ func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) {
return
}
// WriteByte appends the byte c to the buffer.
// The returned error is always nil, but is included
// to match bufio.Writer's WriteByte.
// If the buffer becomes too large, WriteByte will panic with
// WriteByte appends the byte c to the buffer, growing the buffer as needed.
// The returned error is always nil, but is included to match bufio.Writer's
// WriteByte. If the buffer becomes too large, WriteByte will panic with
// ErrTooLarge.
func (b *Buffer) WriteByte(c byte) error {
b.lastRead = opInvalid
......@@ -228,12 +230,10 @@ func (b *Buffer) WriteByte(c byte) error {
return nil
}
// WriteRune appends the UTF-8 encoding of Unicode
// code point r to the buffer, returning its length and
// an error, which is always nil but is included
// to match bufio.Writer's WriteRune.
// If the buffer becomes too large, WriteRune will panic with
// ErrTooLarge.
// WriteRune appends the UTF-8 encoding of Unicode code point r to the
// buffer, returning its length and an error, which is always nil but is
// included to match bufio.Writer's WriteRune. The buffer is grown as needed;
// if it becomes too large, WriteRune will panic with ErrTooLarge.
func (b *Buffer) WriteRune(r rune) (n int, err error) {
if r < utf8.RuneSelf {
b.WriteByte(byte(r))
......
......@@ -475,3 +475,53 @@ func TestUnreadByte(t *testing.T) {
t.Errorf("ReadByte = %q; want %q", c, 'm')
}
}
// Tests that we occasionally compact. Issue 5154.
func TestBufferGrowth(t *testing.T) {
var b Buffer
buf := make([]byte, 1024)
b.Write(buf[0:1])
var cap0 int
for i := 0; i < 5<<10; i++ {
b.Write(buf)
b.Read(buf)
if i == 0 {
cap0 = b.Cap()
}
}
cap1 := b.Cap()
// (*Buffer).grow allows for 2x capacity slop before sliding,
// so set our error threshold at 3x.
if cap1 > cap0*3 {
t.Errorf("buffer cap = %d; too big (grew from %d)", cap1, cap0)
}
}
// From Issue 5154.
func BenchmarkBufferNotEmptyWriteRead(b *testing.B) {
buf := make([]byte, 1024)
for i := 0; i < b.N; i++ {
var b Buffer
b.Write(buf[0:1])
for i := 0; i < 5<<10; i++ {
b.Write(buf)
b.Read(buf)
}
}
}
// Check that we don't compact too often. From Issue 5154.
func BenchmarkBufferFullSmallReads(b *testing.B) {
buf := make([]byte, 1024)
for i := 0; i < b.N; i++ {
var b Buffer
b.Write(buf)
for b.Len()+20 < b.Cap() {
b.Write(buf[:10])
}
for i := 0; i < 5<<10; i++ {
b.Read(buf[:1])
b.Write(buf[:1])
}
}
}
......@@ -37,10 +37,6 @@ func Compare(a, b []byte) int {
return 0
}
// Equal returns a boolean reporting whether a == b.
// A nil argument is equivalent to an empty slice.
func Equal(a, b []byte) bool
func equalPortable(a, b []byte) bool {
if len(a) != len(b) {
return false
......@@ -465,10 +461,10 @@ func isSeparator(r rune) bool {
return unicode.IsSpace(r)
}
// BUG(r): The rule Title uses for word boundaries does not handle Unicode punctuation properly.
// Title returns a copy of s with all Unicode letters that begin words
// mapped to their title case.
//
// BUG: The rule Title uses for word boundaries does not handle Unicode punctuation properly.
func Title(s []byte) []byte {
// Use a closure here to remember state.
// Hackish but effective. Depends on Map scanning in order and calling
......@@ -515,6 +511,24 @@ func TrimFunc(s []byte, f func(r rune) bool) []byte {
return TrimRightFunc(TrimLeftFunc(s, f), f)
}
// TrimPrefix returns s without the provided leading prefix string.
// If s doesn't start with prefix, s is returned unchanged.
func TrimPrefix(s, prefix []byte) []byte {
if HasPrefix(s, prefix) {
return s[len(prefix):]
}
return s
}
// TrimSuffix returns s without the provided trailing suffix string.
// If s doesn't end with suffix, s is returned unchanged.
func TrimSuffix(s, suffix []byte) []byte {
if HasSuffix(s, suffix) {
return s[:len(s)-len(suffix)]
}
return s
}
// IndexFunc interprets s as a sequence of UTF-8-encoded Unicode code points.
// It returns the byte index in s of the first Unicode
// code point satisfying f(c), or -1 if none do.
......@@ -553,7 +567,10 @@ func indexFunc(s []byte, f func(r rune) bool, truth bool) int {
// inverted.
func lastIndexFunc(s []byte, f func(r rune) bool, truth bool) int {
for i := len(s); i > 0; {
r, size := utf8.DecodeLastRune(s[0:i])
r, size := rune(s[i-1]), 1
if r >= utf8.RuneSelf {
r, size = utf8.DecodeLastRune(s[0:i])
}
i -= size
if f(r) == truth {
return i
......
......@@ -4,5 +4,13 @@
package bytes
//go:noescape
// IndexByte returns the index of the first instance of c in s, or -1 if c is not present in s.
func IndexByte(s []byte, c byte) int // asm_$GOARCH.s
//go:noescape
// Equal returns a boolean reporting whether a == b.
// A nil argument is equivalent to an empty slice.
func Equal(a, b []byte) bool // asm_arm.s or ../runtime/asm_{386,amd64}.s
......@@ -61,6 +61,10 @@ var compareTests = []struct {
{[]byte("ab"), []byte("x"), -1},
{[]byte("x"), []byte("a"), 1},
{[]byte("b"), []byte("x"), -1},
// test runtime·memeq's chunked implementation
{[]byte("abcdefgh"), []byte("abcdefgh"), 0},
{[]byte("abcdefghi"), []byte("abcdefghi"), 0},
{[]byte("abcdefghi"), []byte("abcdefghj"), -1},
// nil tests
{nil, nil, 0},
{[]byte(""), nil, 0},
......@@ -86,6 +90,58 @@ func TestCompare(t *testing.T) {
}
}
func TestEqual(t *testing.T) {
var size = 128
if testing.Short() {
size = 32
}
a := make([]byte, size)
b := make([]byte, size)
b_init := make([]byte, size)
// randomish but deterministic data
for i := 0; i < size; i++ {
a[i] = byte(17 * i)
b_init[i] = byte(23*i + 100)
}
for len := 0; len <= size; len++ {
for x := 0; x <= size-len; x++ {
for y := 0; y <= size-len; y++ {
copy(b, b_init)
copy(b[y:y+len], a[x:x+len])
if !Equal(a[x:x+len], b[y:y+len]) || !Equal(b[y:y+len], a[x:x+len]) {
t.Errorf("Equal(%d, %d, %d) = false", len, x, y)
}
}
}
}
}
// make sure Equal returns false for minimally different strings. The data
// is all zeros except for a single one in one location.
func TestNotEqual(t *testing.T) {
var size = 128
if testing.Short() {
size = 32
}
a := make([]byte, size)
b := make([]byte, size)
for len := 0; len <= size; len++ {
for x := 0; x <= size-len; x++ {
for y := 0; y <= size-len; y++ {
for diffpos := x; diffpos < x+len; diffpos++ {
a[diffpos] = 1
if Equal(a[x:x+len], b[y:y+len]) || Equal(b[y:y+len], a[x:x+len]) {
t.Errorf("NotEqual(%d, %d, %d, %d) = true", len, x, y, diffpos)
}
a[diffpos] = 0
}
}
}
}
}
var indexTests = []BinOpTest{
{"", "", 0},
{"", "a", -1},
......@@ -303,10 +359,30 @@ func bmIndexByte(b *testing.B, index func([]byte, byte) int, n int) {
buf[n-1] = '\x00'
}
func BenchmarkEqual0(b *testing.B) {
var buf [4]byte
buf1 := buf[0:0]
buf2 := buf[1:1]
for i := 0; i < b.N; i++ {
eq := Equal(buf1, buf2)
if !eq {
b.Fatal("bad equal")
}
}
}
func BenchmarkEqual1(b *testing.B) { bmEqual(b, Equal, 1) }
func BenchmarkEqual6(b *testing.B) { bmEqual(b, Equal, 6) }
func BenchmarkEqual9(b *testing.B) { bmEqual(b, Equal, 9) }
func BenchmarkEqual15(b *testing.B) { bmEqual(b, Equal, 15) }
func BenchmarkEqual16(b *testing.B) { bmEqual(b, Equal, 16) }
func BenchmarkEqual20(b *testing.B) { bmEqual(b, Equal, 20) }
func BenchmarkEqual32(b *testing.B) { bmEqual(b, Equal, 32) }
func BenchmarkEqual4K(b *testing.B) { bmEqual(b, Equal, 4<<10) }
func BenchmarkEqual4M(b *testing.B) { bmEqual(b, Equal, 4<<20) }
func BenchmarkEqual64M(b *testing.B) { bmEqual(b, Equal, 64<<20) }
func BenchmarkEqualPort1(b *testing.B) { bmEqual(b, EqualPortable, 1) }
func BenchmarkEqualPort6(b *testing.B) { bmEqual(b, EqualPortable, 6) }
func BenchmarkEqualPort32(b *testing.B) { bmEqual(b, EqualPortable, 32) }
func BenchmarkEqualPort4K(b *testing.B) { bmEqual(b, EqualPortable, 4<<10) }
func BenchmarkEqualPortable4M(b *testing.B) { bmEqual(b, EqualPortable, 4<<20) }
......@@ -794,8 +870,8 @@ func TestRunes(t *testing.T) {
}
type TrimTest struct {
f string
in, cutset, out string
f string
in, arg, out string
}
var trimTests = []TrimTest{
......@@ -820,12 +896,17 @@ var trimTests = []TrimTest{
{"TrimRight", "", "123", ""},
{"TrimRight", "", "", ""},
{"TrimRight", "☺\xc0", "☺", "☺\xc0"},
{"TrimPrefix", "aabb", "a", "abb"},
{"TrimPrefix", "aabb", "b", "aabb"},
{"TrimSuffix", "aabb", "a", "aabb"},
{"TrimSuffix", "aabb", "b", "aab"},
}
func TestTrim(t *testing.T) {
for _, tc := range trimTests {
name := tc.f
var f func([]byte, string) []byte
var fb func([]byte, []byte) []byte
switch name {
case "Trim":
f = Trim
......@@ -833,12 +914,21 @@ func TestTrim(t *testing.T) {
f = TrimLeft
case "TrimRight":
f = TrimRight
case "TrimPrefix":
fb = TrimPrefix
case "TrimSuffix":
fb = TrimSuffix
default:
t.Errorf("Undefined trim function %s", name)
}
actual := string(f([]byte(tc.in), tc.cutset))
var actual string
if f != nil {
actual = string(f([]byte(tc.in), tc.arg))
} else {
actual = string(fb([]byte(tc.in), []byte(tc.arg)))
}
if actual != tc.out {
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.arg, actual, tc.out)
}
}
}
......@@ -1059,3 +1149,10 @@ func BenchmarkFieldsFunc(b *testing.B) {
FieldsFunc(fieldsInput, unicode.IsSpace)
}
}
func BenchmarkTrimSpace(b *testing.B) {
s := []byte(" Some text. \n")
for i := 0; i < b.N; i++ {
TrimSpace(s)
}
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// +build linux
package bytes_test
import (
. "bytes"
"syscall"
"testing"
"unsafe"
)
// This file tests the situation where memeq is checking
// data very near to a page boundary. We want to make sure
// equal does not read across the boundary and cause a page
// fault where it shouldn't.
// This test runs only on linux. The code being tested is
// not OS-specific, so it does not need to be tested on all
// operating systems.
func TestEqualNearPageBoundary(t *testing.T) {
pagesize := syscall.Getpagesize()
b := make([]byte, 4*pagesize)
i := pagesize
for ; uintptr(unsafe.Pointer(&b[i]))%uintptr(pagesize) != 0; i++ {
}
syscall.Mprotect(b[i-pagesize:i], 0)
syscall.Mprotect(b[i+pagesize:i+2*pagesize], 0)
defer syscall.Mprotect(b[i-pagesize:i], syscall.PROT_READ|syscall.PROT_WRITE)
defer syscall.Mprotect(b[i+pagesize:i+2*pagesize], syscall.PROT_READ|syscall.PROT_WRITE)
// both of these should fault
//pagesize += int(b[i-1])
//pagesize += int(b[i+pagesize])
for j := 0; j < pagesize; j++ {
b[i+j] = 'A'
}
for j := 0; j <= pagesize; j++ {
Equal(b[i:i+j], b[i+pagesize-j:i+pagesize])
Equal(b[i+pagesize-j:i+pagesize], b[i:i+j])
}
}
......@@ -66,3 +66,20 @@ func ExampleCompare_search() {
// Found it!
}
}
func ExampleTrimSuffix() {
var b = []byte("Hello, goodbye, etc!")
b = bytes.TrimSuffix(b, []byte("goodbye, etc!"))
b = bytes.TrimSuffix(b, []byte("gopher"))
b = append(b, bytes.TrimSuffix([]byte("world!"), []byte("x!"))...)
os.Stdout.Write(b)
// Output: Hello, world!
}
func ExampleTrimPrefix() {
var b = []byte("Goodbye,, world!")
b = bytes.TrimPrefix(b, []byte("Goodbye,"))
b = bytes.TrimPrefix(b, []byte("See ya,"))
fmt.Printf("Hello%s", b)
// Output: Hello, world!
}
......@@ -7,3 +7,7 @@ package bytes
// Export func for testing
var IndexBytePortable = indexBytePortable
var EqualPortable = equalPortable
func (b *Buffer) Cap() int {
return cap(b.buf)
}
......@@ -54,8 +54,6 @@ func (t huffmanTree) Decode(br *bitReader) (v uint16) {
nodeIndex = node.right
}
}
panic("unreachable")
}
// newHuffmanTree builds a Huffman tree from a slice containing the code
......
......@@ -158,7 +158,6 @@ func (b *syncBuffer) Read(p []byte) (n int, err error) {
}
<-b.ready
}
panic("unreachable")
}
func (b *syncBuffer) signal() {
......
......@@ -263,7 +263,6 @@ func (f *decompressor) Read(b []byte) (int, error) {
}
f.step(f)
}
panic("unreachable")
}
func (f *decompressor) Close() error {
......@@ -495,7 +494,6 @@ func (f *decompressor) huffmanBlock() {
return
}
}
panic("unreached")
}
// copyHist copies f.copyLen bytes from f.hist (f.copyDist bytes ago) to itself.
......@@ -642,7 +640,6 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) {
return int(chunk >> huffmanValueShift), nil
}
}
return 0, CorruptInputError(f.roffset)
}
// Flush any buffered output to the underlying writer.
......
......@@ -99,5 +99,4 @@ func offsetCode(off uint32) uint32 {
default:
return offsetCodes[off>>14] + 28
}
panic("unreachable")
}
......@@ -120,7 +120,6 @@ func (z *Reader) readString() (string, error) {
return string(z.buf[0:i]), nil
}
}
panic("not reached")
}
func (z *Reader) read2() (uint32, error) {
......
......@@ -28,7 +28,7 @@ type Writer struct {
Header
w io.Writer
level int
compressor io.WriteCloser
compressor *flate.Writer
digest hash.Hash32
size uint32
closed bool
......@@ -191,6 +191,28 @@ func (z *Writer) Write(p []byte) (int, error) {
return n, z.err
}
// Flush flushes any pending compressed data to the underlying writer.
//
// It is useful mainly in compressed network protocols, to ensure that
// a remote reader has enough data to reconstruct a packet. Flush does
// not return until the data has been written. If the underlying
// writer returns an error, Flush returns that error.
//
// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
func (z *Writer) Flush() error {
if z.err != nil {
return z.err
}
if z.closed {
return nil
}
if z.compressor == nil {
z.Write(nil)
}
z.err = z.compressor.Flush()
return z.err
}
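A minimal sketch (not part of the commit) of Flush in use; the bytes.Buffer stands in for a network connection:

package main

import (
    "bytes"
    "compress/gzip"
    "fmt"
    "log"
)

func main() {
    var conn bytes.Buffer // stand-in for a net.Conn
    zw := gzip.NewWriter(&conn)
    if _, err := zw.Write([]byte("ping")); err != nil {
        log.Fatal(err)
    }
    // Flush (Z_SYNC_FLUSH) makes the compressed bytes available to the
    // peer without closing the stream.
    if err := zw.Flush(); err != nil {
        log.Fatal(err)
    }
    fmt.Println(conn.Len() > 0) // true
    zw.Close()
}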
// Close closes the Writer. It does not close the underlying io.Writer.
func (z *Writer) Close() error {
if z.err != nil {
......
......@@ -157,3 +157,43 @@ func TestLatin1RoundTrip(t *testing.T) {
}
}
}
func TestWriterFlush(t *testing.T) {
buf := new(bytes.Buffer)
w := NewWriter(buf)
w.Comment = "comment"
w.Extra = []byte("extra")
w.ModTime = time.Unix(1e8, 0)
w.Name = "name"
n0 := buf.Len()
if n0 != 0 {
t.Fatalf("buffer size = %d before writes; want 0", n0)
}
if err := w.Flush(); err != nil {
t.Fatal(err)
}
n1 := buf.Len()
if n1 == 0 {
t.Fatal("no data after first flush")
}
w.Write([]byte("x"))
n2 := buf.Len()
if n1 != n2 {
t.Fatalf("after writing a single byte, size changed from %d to %d; want no change", n1, n2)
}
if err := w.Flush(); err != nil {
t.Fatal(err)
}
n3 := buf.Len()
if n2 == n3 {
t.Fatal("Flush didn't flush any data")
}
}
......@@ -121,7 +121,6 @@ func (d *decoder) Read(b []byte) (int, error) {
}
d.decode()
}
panic("unreachable")
}
// decode decompresses bytes from r and leaves them in d.toRead.
......@@ -203,7 +202,6 @@ func (d *decoder) decode() {
return
}
}
panic("unreachable")
}
func (d *decoder) flush() {
......
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This example demonstrates a priority queue built using the heap interface.
package heap_test
import (
"container/heap"
"fmt"
)
// An Item is something we manage in a priority queue.
type Item struct {
value string // The value of the item; arbitrary.
priority int // The priority of the item in the queue.
// The index is needed by changePriority and is maintained by the heap.Interface methods.
index int // The index of the item in the heap.
}
// A PriorityQueue implements heap.Interface and holds Items.
type PriorityQueue []*Item
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
// We want Pop to give us the highest, not lowest, priority so we use greater than here.
return pq[i].priority > pq[j].priority
}
func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq *PriorityQueue) Push(x interface{}) {
// Push and Pop use pointer receivers because they modify the slice's length,
// not just its contents.
n := len(*pq)
item := x.(*Item)
item.index = n
*pq = append(*pq, item)
}
func (pq *PriorityQueue) Pop() interface{} {
a := *pq
n := len(a)
item := a[n-1]
item.index = -1 // for safety
*pq = a[0 : n-1]
return item
}
// update is not used by the example but shows how to take the top item from
// the queue, update its priority and value, and put it back.
func (pq *PriorityQueue) update(value string, priority int) {
item := heap.Pop(pq).(*Item)
item.value = value
item.priority = priority
heap.Push(pq, item)
}
// changePriority is not used by the example but shows how to change the
// priority of an arbitrary item.
func (pq *PriorityQueue) changePriority(item *Item, priority int) {
heap.Remove(pq, item.index)
item.priority = priority
heap.Push(pq, item)
}
// This example pushes 10 items into a PriorityQueue and takes them out in
// order of priority.
func Example() {
const nItem = 10
// Random priorities for the items (a permutation of 0..9, times 11).
priorities := [nItem]int{
77, 22, 44, 55, 11, 88, 33, 99, 00, 66,
}
values := [nItem]string{
"zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
}
// Create a priority queue and put some items in it.
pq := make(PriorityQueue, 0, nItem)
for i := 0; i < cap(pq); i++ {
item := &Item{
value: values[i],
priority: priorities[i],
}
heap.Push(&pq, item)
}
// Take the items out; should arrive in decreasing priority order.
// For example, the highest priority (99) is the seventh item, so output starts with 99:"seven".
for i := 0; i < nItem; i++ {
item := heap.Pop(&pq).(*Item)
fmt.Printf("%.2d:%s ", item.priority, item.value)
}
// Output:
// 99:seven 88:five 77:zero 66:nine 55:three 44:two 33:six 22:one 11:four 00:eight
}
......@@ -4,13 +4,13 @@
// Package heap provides heap operations for any type that implements
// heap.Interface. A heap is a tree with the property that each node is the
// highest-valued node in its subtree.
// minimum-valued node in its subtree.
//
// A heap is a common way to implement a priority queue. To build a priority
// queue, implement the Heap interface with the (negative) priority as the
// ordering for the Less method, so Push adds items while Pop removes the
// highest-priority item from the queue. The Examples include such an
// implementation; the file example_test.go has the complete source.
// implementation; the file example_pq_test.go has the complete source.
//
package heap
......@@ -90,7 +90,7 @@ func up(h Interface, j int) {
func down(h Interface, i, n int) {
for {
j1 := 2*i + 1
if j1 >= n {
if j1 >= n || j1 < 0 { // j1 < 0 after int overflow
break
}
j := j1 // left child
......
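A tiny illustration (not part of the commit) of the overflow the new j1 < 0 check guards against; it assumes 64-bit arithmetic and uses int64 explicitly:

package main

import (
    "fmt"
    "math"
)

func main() {
    var i int64 = math.MaxInt64/2 + 1
    fmt.Println(2*i + 1) // -9223372036854775807: a negative "child index"
}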
......@@ -2,10 +2,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package heap_test
package heap
import (
. "container/heap"
"testing"
)
......
......@@ -42,6 +42,12 @@ func NewCBCEncrypter(b Block, iv []byte) BlockMode {
func (x *cbcEncrypter) BlockSize() int { return x.blockSize }
func (x *cbcEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
for i := 0; i < x.blockSize; i++ {
x.iv[i] ^= src[i]
......@@ -70,6 +76,12 @@ func NewCBCDecrypter(b Block, iv []byte) BlockMode {
func (x *cbcDecrypter) BlockSize() int { return x.blockSize }
func (x *cbcDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Decrypt(x.tmp, src[:x.blockSize])
for i := 0; i < x.blockSize; i++ {
......
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cipher_test
import (
"crypto/aes"
"crypto/cipher"
"testing"
)
func TestCryptBlocks(t *testing.T) {
buf := make([]byte, 16)
block, _ := aes.NewCipher(buf)
mode := cipher.NewCBCDecrypter(block, buf)
mustPanic(t, "crypto/cipher: input not full blocks", func() { mode.CryptBlocks(buf, buf[:3]) })
mustPanic(t, "crypto/cipher: output smaller than input", func() { mode.CryptBlocks(buf[:3], buf) })
mode = cipher.NewCBCEncrypter(block, buf)
mustPanic(t, "crypto/cipher: input not full blocks", func() { mode.CryptBlocks(buf, buf[:3]) })
mustPanic(t, "crypto/cipher: output smaller than input", func() { mode.CryptBlocks(buf[:3], buf) })
}
func mustPanic(t *testing.T, msg string, f func()) {
defer func() {
err := recover()
if err == nil {
t.Errorf("function did not panic, wanted %q", msg)
} else if err != msg {
t.Errorf("got panic %v, wanted %q", err, msg)
}
}()
f()
}
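The new checks make the contract explicit: CryptBlocks requires whole blocks and a destination at least as large as the source. A minimal usage sketch (all-zero key and IV chosen for brevity only):

package main

import (
	"crypto/aes"
	"crypto/cipher"
	"fmt"
)

func main() {
	key := make([]byte, 16)           // demo key; never use an all-zero key in practice
	iv := make([]byte, aes.BlockSize) // demo IV; must be unpredictable per message
	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	plaintext := []byte("exactly sixteen.") // 16 bytes: one full AES block
	ciphertext := make([]byte, len(plaintext))
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
	fmt.Printf("%x\n", ciphertext)
	// A short src or dst here would now panic with the messages tested above.
}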
......@@ -233,7 +233,7 @@ func ExampleStreamReader() {
}
defer outFile.Close()
reader := &cipher.StreamReader{stream, inFile}
reader := &cipher.StreamReader{S: stream, R: inFile}
// Copy the input file to the output file, decrypting as we go.
if _, err := io.Copy(outFile, reader); err != nil {
panic(err)
......@@ -270,7 +270,7 @@ func ExampleStreamWriter() {
}
defer outFile.Close()
writer := &cipher.StreamWriter{stream, outFile, nil}
writer := &cipher.StreamWriter{S: stream, W: outFile}
// Copy the input file to the output file, encrypting as we go.
if _, err := io.Copy(writer, inFile); err != nil {
panic(err)
......
......@@ -144,8 +144,6 @@ GeneratePrimes:
params.G = g
return
}
panic("unreachable")
}
// GenerateKey generates a public&private key pair. The Parameters of the
......
......@@ -63,8 +63,9 @@ func testParameterGeneration(t *testing.T, sizes ParameterSizes, L, N int) {
}
func TestParameterGeneration(t *testing.T) {
// This test is too slow to run all the time.
return
if testing.Short() {
t.Skip("skipping parameter generation test in short mode")
}
testParameterGeneration(t, L1024N160, 1024, 160)
testParameterGeneration(t, L2048N224, 2048, 224)
......
......@@ -49,7 +49,7 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error)
return
}
// GenerateKey generates a public&private key pair.
// GenerateKey generates a public and private key pair.
func GenerateKey(c elliptic.Curve, rand io.Reader) (priv *PrivateKey, err error) {
k, err := randFieldElement(c, rand)
if err != nil {
......
......@@ -161,6 +161,11 @@ var data = Data{
}
var program = `
// DO NOT EDIT.
// Generate with: go run gen.go{{if .Full}} -full{{end}} | gofmt >md5block.go
// +build !amd64
package md5
import (
......@@ -186,6 +191,16 @@ import (
}
{{end}}
const x86 = runtime.GOARCH == "amd64" || runtime.GOARCH == "386"
var littleEndian bool
func init() {
x := uint32(0x04030201)
y := [4]byte{0x1, 0x2, 0x3, 0x4}
littleEndian = *(*[4]byte)(unsafe.Pointer(&x)) == y
}
func block(dig *digest, p []byte) {
a := dig.s[0]
b := dig.s[1]
......@@ -197,13 +212,13 @@ func block(dig *digest, p []byte) {
aa, bb, cc, dd := a, b, c, d
// This is a constant condition - it is not evaluated on each iteration.
if runtime.GOARCH == "amd64" || runtime.GOARCH == "386" {
if x86 {
// MD5 was designed so that x86 processors can just iterate
// over the block data directly as uint32s, and we generate
// less code and run 1.3x faster if we take advantage of that.
// My apologies.
X = (*[16]uint32)(unsafe.Pointer(&p[0]))
} else if uintptr(unsafe.Pointer(&p[0]))&(unsafe.Alignof(uint32(0))-1) == 0 {
} else if littleEndian && uintptr(unsafe.Pointer(&p[0]))&(unsafe.Alignof(uint32(0))-1) == 0 {
X = (*[16]uint32)(unsafe.Pointer(&p[0]))
} else {
X = &xbuf
......
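The littleEndian check added to the generator's template is the usual pointer-punning probe; a standalone sketch of the same idea:

package main

import (
	"fmt"
	"unsafe"
)

// isLittleEndian reports whether the host stores the least significant
// byte of a uint32 first, mirroring the init check in the generated code.
func isLittleEndian() bool {
	x := uint32(0x04030201)
	b := *(*[4]byte)(unsafe.Pointer(&x))
	return b == [4]byte{0x01, 0x02, 0x03, 0x04}
}

func main() {
	fmt.Println(isLittleEndian())
}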
......@@ -2,10 +2,9 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package md5_test
package md5
import (
"crypto/md5"
"fmt"
"io"
"testing"
......@@ -54,7 +53,7 @@ var golden = []md5Test{
func TestGolden(t *testing.T) {
for i := 0; i < len(golden); i++ {
g := golden[i]
c := md5.New()
c := New()
buf := make([]byte, len(g.in)+4)
for j := 0; j < 3+4; j++ {
if j < 2 {
......@@ -79,14 +78,14 @@ func TestGolden(t *testing.T) {
}
func ExampleNew() {
h := md5.New()
h := New()
io.WriteString(h, "The fog is getting thicker!")
io.WriteString(h, "And Leon's getting laaarger!")
fmt.Printf("%x", h.Sum(nil))
// Output: e2c569be17396eca2a2e3c11578123ed
}
var bench = md5.New()
var bench = New()
var buf = make([]byte, 8192+1)
var sum = make([]byte, bench.Size())
......
// DO NOT EDIT.
// Generate with: go run gen.go -full | gofmt >md5block.go
// +build !amd64,!386
package md5
import (
......
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64 386
package md5
func block(dig *digest, p []byte)
......@@ -98,12 +98,13 @@ func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
return
}
}
return
}
// Int returns a uniform random value in [0, max).
// Int returns a uniform random value in [0, max). It panics if max <= 0.
func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) {
if max.Sign() <= 0 {
panic("crypto/rand: argument to Int is <= 0")
}
k := (max.BitLen() + 7) / 8
// b is the number of bits in the most significant byte of max.
......@@ -130,6 +131,4 @@ func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) {
return
}
}
return
}
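Given the newly documented panic, callers must pass a strictly positive bound; a minimal usage sketch:

package main

import (
	"crypto/rand"
	"fmt"
	"log"
	"math/big"
)

func main() {
	max := big.NewInt(100)
	n, err := rand.Int(rand.Reader, max) // uniform in [0, 100); panics if max <= 0
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(n)
}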
......@@ -13,7 +13,7 @@ import "strconv"
// A Cipher is an instance of RC4 using a particular key.
type Cipher struct {
s [256]byte
s [256]uint32
i, j uint8
}
......@@ -32,27 +32,16 @@ func NewCipher(key []byte) (*Cipher, error) {
}
var c Cipher
for i := 0; i < 256; i++ {
c.s[i] = uint8(i)
c.s[i] = uint32(i)
}
var j uint8 = 0
for i := 0; i < 256; i++ {
j += c.s[i] + key[i%k]
j += uint8(c.s[i]) + key[i%k]
c.s[i], c.s[j] = c.s[j], c.s[i]
}
return &c, nil
}
// XORKeyStream sets dst to the result of XORing src with the key stream.
// Dst and src may be the same slice but otherwise should not overlap.
func (c *Cipher) XORKeyStream(dst, src []byte) {
for i := range src {
c.i += 1
c.j += c.s[c.i]
c.s[c.i], c.s[c.j] = c.s[c.j], c.s[c.i]
dst[i] = src[i] ^ c.s[c.s[c.i]+c.s[c.j]]
}
}
// Reset zeros the key data so that it will no longer appear in the
// process's memory.
func (c *Cipher) Reset() {
......
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64 arm 386
package rc4
func xorKeyStream(dst, src *byte, n int, state *[256]uint32, i, j *uint8)
// XORKeyStream sets dst to the result of XORing src with the key stream.
// Dst and src may be the same slice but otherwise should not overlap.
func (c *Cipher) XORKeyStream(dst, src []byte) {
if len(src) == 0 {
return
}
xorKeyStream(&dst[0], &src[0], len(src), &c.s, &c.i, &c.j)
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64,!arm,!386
package rc4
// XORKeyStream sets dst to the result of XORing src with the key stream.
// Dst and src may be the same slice but otherwise should not overlap.
func (c *Cipher) XORKeyStream(dst, src []byte) {
i, j := c.i, c.j
for k, v := range src {
i += 1
j += uint8(c.s[i])
c.s[i], c.s[j] = c.s[j], c.s[i]
dst[k] = v ^ byte(c.s[byte(c.s[i]+c.s[j])])
}
c.i, c.j = i, j
}
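For context, typical use of the package is unchanged by the assembly/portable split; a minimal sketch (demo key only, RC4 is shown purely as an API example):

package main

import (
	"crypto/rc4"
	"fmt"
)

func main() {
	key := []byte{0x01, 0x02, 0x03, 0x04, 0x05}
	c, err := rc4.NewCipher(key)
	if err != nil {
		panic(err)
	}
	src := []byte("attack at dawn")
	dst := make([]byte, len(src))
	c.XORKeyStream(dst, src) // dst now holds src XORed with the keystream
	fmt.Printf("%x\n", dst)
}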
......@@ -5,6 +5,8 @@
package rc4
import (
"bytes"
"fmt"
"testing"
)
......@@ -37,23 +39,124 @@ var golden = []rc4Test{
[]byte{0x57, 0x69, 0x6b, 0x69},
[]byte{0x60, 0x44, 0xdb, 0x6d, 0x41, 0xb7},
},
{
[]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
[]byte{
0xde, 0x18, 0x89, 0x41, 0xa3, 0x37, 0x5d, 0x3a,
0x8a, 0x06, 0x1e, 0x67, 0x57, 0x6e, 0x92, 0x6d,
0xc7, 0x1a, 0x7f, 0xa3, 0xf0, 0xcc, 0xeb, 0x97,
0x45, 0x2b, 0x4d, 0x32, 0x27, 0x96, 0x5f, 0x9e,
0xa8, 0xcc, 0x75, 0x07, 0x6d, 0x9f, 0xb9, 0xc5,
0x41, 0x7a, 0xa5, 0xcb, 0x30, 0xfc, 0x22, 0x19,
0x8b, 0x34, 0x98, 0x2d, 0xbb, 0x62, 0x9e, 0xc0,
0x4b, 0x4f, 0x8b, 0x05, 0xa0, 0x71, 0x08, 0x50,
0x92, 0xa0, 0xc3, 0x58, 0x4a, 0x48, 0xe4, 0xa3,
0x0a, 0x39, 0x7b, 0x8a, 0xcd, 0x1d, 0x00, 0x9e,
0xc8, 0x7d, 0x68, 0x11, 0xf2, 0x2c, 0xf4, 0x9c,
0xa3, 0xe5, 0x93, 0x54, 0xb9, 0x45, 0x15, 0x35,
0xa2, 0x18, 0x7a, 0x86, 0x42, 0x6c, 0xca, 0x7d,
0x5e, 0x82, 0x3e, 0xba, 0x00, 0x44, 0x12, 0x67,
0x12, 0x57, 0xb8, 0xd8, 0x60, 0xae, 0x4c, 0xbd,
0x4c, 0x49, 0x06, 0xbb, 0xc5, 0x35, 0xef, 0xe1,
0x58, 0x7f, 0x08, 0xdb, 0x33, 0x95, 0x5c, 0xdb,
0xcb, 0xad, 0x9b, 0x10, 0xf5, 0x3f, 0xc4, 0xe5,
0x2c, 0x59, 0x15, 0x65, 0x51, 0x84, 0x87, 0xfe,
0x08, 0x4d, 0x0e, 0x3f, 0x03, 0xde, 0xbc, 0xc9,
0xda, 0x1c, 0xe9, 0x0d, 0x08, 0x5c, 0x2d, 0x8a,
0x19, 0xd8, 0x37, 0x30, 0x86, 0x16, 0x36, 0x92,
0x14, 0x2b, 0xd8, 0xfc, 0x5d, 0x7a, 0x73, 0x49,
0x6a, 0x8e, 0x59, 0xee, 0x7e, 0xcf, 0x6b, 0x94,
0x06, 0x63, 0xf4, 0xa6, 0xbe, 0xe6, 0x5b, 0xd2,
0xc8, 0x5c, 0x46, 0x98, 0x6c, 0x1b, 0xef, 0x34,
0x90, 0xd3, 0x7b, 0x38, 0xda, 0x85, 0xd3, 0x2e,
0x97, 0x39, 0xcb, 0x23, 0x4a, 0x2b, 0xe7, 0x40,
},
},
}
func testEncrypt(t *testing.T, desc string, c *Cipher, src, expect []byte) {
dst := make([]byte, len(src))
c.XORKeyStream(dst, src)
for i, v := range dst {
if v != expect[i] {
t.Fatalf("%s: mismatch at byte %d:\nhave %x\nwant %x", desc, i, dst, expect)
}
}
}
func TestGolden(t *testing.T) {
for i := 0; i < len(golden); i++ {
g := golden[i]
c, err := NewCipher(g.key)
if err != nil {
t.Errorf("Failed to create cipher at golden index %d", i)
return
for gi, g := range golden {
data := make([]byte, len(g.keystream))
for i := range data {
data[i] = byte(i)
}
keystream := make([]byte, len(g.keystream))
c.XORKeyStream(keystream, keystream)
for j, v := range keystream {
if g.keystream[j] != v {
t.Errorf("Failed at golden index %d", i)
break
expect := make([]byte, len(g.keystream))
for i := range expect {
expect[i] = byte(i) ^ g.keystream[i]
}
for size := 1; size <= len(g.keystream); size++ {
c, err := NewCipher(g.key)
if err != nil {
t.Fatalf("#%d: NewCipher: %v", gi, err)
}
off := 0
for off < len(g.keystream) {
n := len(g.keystream) - off
if n > size {
n = size
}
desc := fmt.Sprintf("#%d@[%d:%d]", gi, off, off+n)
testEncrypt(t, desc, c, data[off:off+n], expect[off:off+n])
off += n
}
}
}
}
func TestBlock(t *testing.T) {
c1a, _ := NewCipher(golden[0].key)
c1b, _ := NewCipher(golden[1].key)
data1 := make([]byte, 1<<20)
for i := range data1 {
c1a.XORKeyStream(data1[i:i+1], data1[i:i+1])
c1b.XORKeyStream(data1[i:i+1], data1[i:i+1])
}
c2a, _ := NewCipher(golden[0].key)
c2b, _ := NewCipher(golden[1].key)
data2 := make([]byte, 1<<20)
c2a.XORKeyStream(data2, data2)
c2b.XORKeyStream(data2, data2)
if !bytes.Equal(data1, data2) {
t.Fatalf("bad block")
}
}
func benchmark(b *testing.B, size int64) {
buf := make([]byte, size)
c, err := NewCipher(golden[0].key)
if err != nil {
panic(err)
}
b.SetBytes(size)
for i := 0; i < b.N; i++ {
c.XORKeyStream(buf, buf)
}
}
func BenchmarkRC4_128(b *testing.B) {
benchmark(b, 128)
}
func BenchmarkRC4_1K(b *testing.B) {
benchmark(b, 1024)
}
func BenchmarkRC4_8K(b *testing.B) {
benchmark(b, 8096)
}
......@@ -150,6 +150,20 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (priv *Priva
NextSetOfPrimes:
for {
todo := bits
// crypto/rand should set the top two bits in each prime.
// Thus each prime has the form
// p_i = 2^bitlen(p_i) × 0.11... (in base 2).
// And the product is:
// P = 2^todo × α
// where α is the product of nprimes numbers of the form 0.11...
//
// If α < 1/2 (which can happen for nprimes > 2), we need to
// shift todo to compensate for lost bits: the mean value of 0.11...
// is 7/8, so todo + shift - nprimes * log2(7/8) ~= bits - 1/2
// will give good results.
if nprimes >= 7 {
todo += (nprimes - 2) / 5
}
for i := 0; i < nprimes; i++ {
primes[i], err = rand.Prime(random, todo/(nprimes-i))
if err != nil {
......@@ -176,8 +190,9 @@ NextSetOfPrimes:
totient.Mul(totient, pminus1)
}
if n.BitLen() != bits {
// This should never happen because crypto/rand should
// set the top two bits in each prime.
// This should never happen for nprimes == 2 because
// crypto/rand should set the top two bits in each prime.
// For nprimes > 2 we hope it does not happen often.
continue NextSetOfPrimes
}
......@@ -188,7 +203,9 @@ NextSetOfPrimes:
g.GCD(priv.D, y, e, totient)
if g.Cmp(bigOne) == 0 {
priv.D.Add(priv.D, totient)
if priv.D.Sign() < 0 {
priv.D.Add(priv.D, totient)
}
priv.Primes = primes
priv.N = n
......
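Rough numbers for the shift described in the comment above, as an illustrative sketch (values chosen arbitrarily):

package main

import "fmt"

func main() {
	// With 8 primes and a 2048-bit modulus, the shift adds (8-2)/5 = 1 bit,
	// so each prime is requested with roughly 2049/8 = 256 bits to offset the
	// leading bits lost when multiplying numbers of the form 0.11... (base 2).
	nprimes, bits := 8, 2048
	todo := bits
	if nprimes >= 7 {
		todo += (nprimes - 2) / 5
	}
	fmt.Println(todo, todo/nprimes) // 2049 256
}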
......@@ -28,11 +28,11 @@ func TestKeyGeneration(t *testing.T) {
}
func Test3PrimeKeyGeneration(t *testing.T) {
size := 768
if testing.Short() {
return
size = 256
}
size := 768
priv, err := GenerateMultiPrimeKey(rand.Reader, 3, size)
if err != nil {
t.Errorf("failed to generate key")
......@@ -41,11 +41,11 @@ func Test3PrimeKeyGeneration(t *testing.T) {
}
func Test4PrimeKeyGeneration(t *testing.T) {
size := 768
if testing.Short() {
return
size = 256
}
size := 768
priv, err := GenerateMultiPrimeKey(rand.Reader, 4, size)
if err != nil {
t.Errorf("failed to generate key")
......@@ -53,6 +53,24 @@ func Test4PrimeKeyGeneration(t *testing.T) {
testKeyBasics(t, priv)
}
func TestNPrimeKeyGeneration(t *testing.T) {
primeSize := 64
maxN := 24
if testing.Short() {
primeSize = 16
maxN = 16
}
// Test that generation of N-prime keys works for N > 4.
for n := 5; n < maxN; n++ {
priv, err := GenerateMultiPrimeKey(rand.Reader, n, 64+n*primeSize)
if err == nil {
testKeyBasics(t, priv)
} else {
t.Errorf("failed to generate %d-prime key", n)
}
}
}
func TestGnuTLSKey(t *testing.T) {
// This is a key generated by `certtool --generate-privkey --bits 128`.
// It's such that de ≢ 1 mod φ(n), but is congruent mod the order of
......@@ -75,6 +93,9 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) {
if err := priv.Validate(); err != nil {
t.Errorf("Validate() failed: %s", err)
}
if priv.D.Cmp(priv.N) > 0 {
t.Errorf("private exponent too large")
}
pub := &priv.PublicKey
m := big.NewInt(42)
......
......@@ -4,10 +4,9 @@
// SHA1 hash algorithm. See RFC 3174.
package sha1_test
package sha1
import (
"crypto/sha1"
"fmt"
"io"
"testing"
......@@ -55,7 +54,7 @@ var golden = []sha1Test{
func TestGolden(t *testing.T) {
for i := 0; i < len(golden); i++ {
g := golden[i]
c := sha1.New()
c := New()
for j := 0; j < 3; j++ {
if j < 2 {
io.WriteString(c, g.in)
......@@ -74,13 +73,13 @@ func TestGolden(t *testing.T) {
}
func ExampleNew() {
h := sha1.New()
h := New()
io.WriteString(h, "His money is twice tainted: 'taint yours and 'taint mine.")
fmt.Printf("% x", h.Sum(nil))
// Output: 59 7f 6a 54 00 10 f9 4c 15 d7 18 06 a9 9a 2c 87 10 e7 47 bd
}
var bench = sha1.New()
var bench = New()
var buf = make([]byte, 8192)
func benchmarkSize(b *testing.B, size int) {
......
......@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64,!386
// SHA1 block step.
// In its own file so that a faster assembly or C version
// can be substituted easily.
......
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build amd64 386
package sha1
func block(dig *digest, p []byte)
......@@ -204,7 +204,24 @@ type Config struct {
// connections using that key are compromised.
SessionTicketKey [32]byte
serverInitOnce sync.Once
serverInitOnce sync.Once // guards calling (*Config).serverInit
}
func (c *Config) serverInit() {
if c.SessionTicketsDisabled {
return
}
// If the key has already been set then we have nothing to do.
for _, b := range c.SessionTicketKey {
if b != 0 {
return
}
}
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
c.SessionTicketsDisabled = true
}
}
func (c *Config) rand() io.Reader {
......
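serverInit only generates a key when SessionTicketKey is still all zeros, so deployments that share tickets across servers can pre-fill it; a hedged sketch:

package main

import (
	"crypto/rand"
	"crypto/tls"
	"io"
	"log"
)

func main() {
	var cfg tls.Config
	// Fill the ticket key explicitly before the lazy generation in serverInit runs.
	if _, err := io.ReadFull(rand.Reader, cfg.SessionTicketKey[:]); err != nil {
		cfg.SessionTicketsDisabled = true // same fallback as serverInit
	}
	log.Printf("tickets disabled: %v", cfg.SessionTicketsDisabled)
}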
......@@ -16,36 +16,80 @@ import (
"crypto/x509/pkix"
"encoding/pem"
"flag"
"fmt"
"log"
"math/big"
"net"
"os"
"strings"
"time"
)
var hostName *string = flag.String("host", "127.0.0.1", "Hostname to generate a certificate for")
var (
host = flag.String("host", "", "Comma-separated hostnames and IPs to generate a certificate for")
validFrom = flag.String("start-date", "", "Creation date formatted as Jan 1 15:04:05 2011")
validFor = flag.Duration("duration", 365*24*time.Hour, "Duration that certificate is valid for")
isCA = flag.Bool("ca", false, "whether this cert should be its own Certificate Authority")
rsaBits = flag.Int("rsa-bits", 1024, "Size of RSA key to generate")
)
func main() {
flag.Parse()
priv, err := rsa.GenerateKey(rand.Reader, 1024)
if len(*host) == 0 {
log.Fatalf("Missing required --host parameter")
}
priv, err := rsa.GenerateKey(rand.Reader, *rsaBits)
if err != nil {
log.Fatalf("failed to generate private key: %s", err)
return
}
now := time.Now()
var notBefore time.Time
if len(*validFrom) == 0 {
notBefore = time.Now()
} else {
notBefore, err = time.Parse("Jan 2 15:04:05 2006", *validFrom)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to parse creation date: %s\n", err)
os.Exit(1)
}
}
notAfter := notBefore.Add(*validFor)
// end of ASN.1 time
endOfTime := time.Date(2049, 12, 31, 23, 59, 59, 0, time.UTC)
if notAfter.After(endOfTime) {
notAfter = endOfTime
}
template := x509.Certificate{
SerialNumber: new(big.Int).SetInt64(0),
Subject: pkix.Name{
CommonName: *hostName,
Organization: []string{"Acme Co"},
},
NotBefore: now.Add(-5 * time.Minute).UTC(),
NotAfter: now.AddDate(1, 0, 0).UTC(), // valid for 1 year.
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
hosts := strings.Split(*host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
SubjectKeyId: []byte{1, 2, 3, 4},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
if *isCA {
template.IsCA = true
template.KeyUsage |= x509.KeyUsageCertSign
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
......
......@@ -33,22 +33,7 @@ func (c *Conn) serverHandshake() error {
// If this is the first server handshake, we generate a random key to
// encrypt the tickets with.
config.serverInitOnce.Do(func() {
if config.SessionTicketsDisabled {
return
}
// If the key has already been set then we have nothing to do.
for _, b := range config.SessionTicketKey {
if b != 0 {
return
}
}
if _, err := io.ReadFull(config.rand(), config.SessionTicketKey[:]); err != nil {
config.SessionTicketsDisabled = true
}
})
config.serverInitOnce.Do(config.serverInit)
hs := serverHandshakeState{
c: c,
......
......@@ -51,6 +51,4 @@ func ParsePKCS8PrivateKey(der []byte) (key interface{}, err error) {
default:
return nil, fmt.Errorf("crypto/x509: PKCS#8 wrapping contained private key with unknown algorithm: %v", privKey.Algo.Algorithm)
}
panic("unreachable")
}
......@@ -5,6 +5,7 @@
package x509
import (
"net"
"runtime"
"strings"
"time"
......@@ -63,14 +64,28 @@ type HostnameError struct {
}
func (h HostnameError) Error() string {
var valid string
c := h.Certificate
if len(c.DNSNames) > 0 {
valid = strings.Join(c.DNSNames, ", ")
var valid string
if ip := net.ParseIP(h.Host); ip != nil {
// Trying to validate an IP
if len(c.IPAddresses) == 0 {
return "x509: cannot validate certificate for " + h.Host + " because it doesn't contain any IP SANs"
}
for _, san := range c.IPAddresses {
if len(valid) > 0 {
valid += ", "
}
valid += san.String()
}
} else {
valid = c.Subject.CommonName
if len(c.DNSNames) > 0 {
valid = strings.Join(c.DNSNames, ", ")
} else {
valid = c.Subject.CommonName
}
}
return "certificate is valid for " + valid + ", not " + h.Host
return "x509: certificate is valid for " + valid + ", not " + h.Host
}
// UnknownAuthorityError results when the certificate issuer is unknown
......@@ -334,6 +349,22 @@ func toLowerCaseASCII(in string) string {
// VerifyHostname returns nil if c is a valid certificate for the named host.
// Otherwise it returns an error describing the mismatch.
func (c *Certificate) VerifyHostname(h string) error {
// IP addresses may be written in [ ].
candidateIP := h
if len(h) >= 3 && h[0] == '[' && h[len(h)-1] == ']' {
candidateIP = h[1 : len(h)-1]
}
if ip := net.ParseIP(candidateIP); ip != nil {
// We only match IP addresses against IP SANs.
// https://tools.ietf.org/html/rfc6125#appendix-B.2
for _, candidate := range c.IPAddresses {
if ip.Equal(candidate) {
return nil
}
}
return HostnameError{c, candidateIP}
}
lowered := toLowerCaseASCII(h)
if len(c.DNSNames) > 0 {
......@@ -389,6 +420,14 @@ func checkChainForKeyUsage(chain []*Certificate, keyUsages []ExtKeyUsage) bool {
for _, usage := range cert.ExtKeyUsage {
if requestedUsage == usage {
continue NextRequestedUsage
} else if requestedUsage == ExtKeyUsageServerAuth &&
(usage == ExtKeyUsageNetscapeServerGatedCrypto ||
usage == ExtKeyUsageMicrosoftServerGatedCrypto) {
// In order to support COMODO
// certificate chains, we have to
// accept Netscape or Microsoft SGC
// usages as equal to ServerAuth.
continue NextRequestedUsage
}
}
......
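On the caller's side, the effect of the IP-SAN change is that VerifyHostname now accepts bare or bracketed IP literals when the certificate carries IP SANs; an illustrative sketch (the package, checkHosts, and its inputs are hypothetical):

package certcheck

import (
	"crypto/x509"
	"log"
)

// checkHosts assumes cert was obtained from x509.ParseCertificate.
func checkHosts(cert *x509.Certificate) {
	for _, h := range []string{"example.com", "127.0.0.1", "[::1]"} {
		if err := cert.VerifyHostname(h); err != nil {
			log.Printf("certificate not valid for %s: %v", h, err)
		}
	}
}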
......@@ -158,6 +158,19 @@ var verifyTests = []verifyTest{
{"Ryan Hurst", "GlobalSign PersonalSign 2 CA - G2"},
},
},
{
leaf: megaLeaf,
intermediates: []string{comodoIntermediate1},
roots: []string{comodoRoot},
currentTime: 1360431182,
// CryptoAPI can find alternative validation paths so we don't
// perform this test with system validation.
systemSkip: true,
expectedChains: [][]string{
{"mega.co.nz", "EssentialSSL CA", "COMODO Certification Authority"},
},
},
}
func expectHostnameError(t *testing.T, i int, err error) (ok bool) {
......@@ -563,3 +576,90 @@ YEvTWbWwGdPytDFPYIl3/6OqNSXSnZ7DxPcdLJq2uyiga8PB/TTIIHYkdM2+1DE0
7y3rH/7TjwDVD7SLu5/SdOfKskuMPTjOEvz3K161mymW06klVhubCIWOro/Gx1Q2
2FQOZ7/2k4uYoOdBTSlb8kTAuzZNgIE0rB2BIYCTz/P6zZIKW0ogbRSH
-----END CERTIFICATE-----`
var megaLeaf = `-----BEGIN CERTIFICATE-----
MIIFOjCCBCKgAwIBAgIQWYE8Dup170kZ+k11Lg51OjANBgkqhkiG9w0BAQUFADBy
MQswCQYDVQQGEwJHQjEbMBkGA1UECBMSR3JlYXRlciBNYW5jaGVzdGVyMRAwDgYD
VQQHEwdTYWxmb3JkMRowGAYDVQQKExFDT01PRE8gQ0EgTGltaXRlZDEYMBYGA1UE
AxMPRXNzZW50aWFsU1NMIENBMB4XDTEyMTIxNDAwMDAwMFoXDTE0MTIxNDIzNTk1
OVowfzEhMB8GA1UECxMYRG9tYWluIENvbnRyb2wgVmFsaWRhdGVkMS4wLAYDVQQL
EyVIb3N0ZWQgYnkgSW5zdHJhIENvcnBvcmF0aW9uIFB0eS4gTFREMRUwEwYDVQQL
EwxFc3NlbnRpYWxTU0wxEzARBgNVBAMTCm1lZ2EuY28ubnowggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQDcxMCClae8BQIaJHBUIVttlLvhbK4XhXPk3RQ3
G5XA6tLZMBQ33l3F9knYJ0YErXtr8IdfYoulRQFmKFMJl9GtWyg4cGQi2Rcr5VN5
S5dA1vu4oyJBxE9fPELcK6Yz1vqaf+n6za+mYTiQYKggVdS8/s8hmNuXP9Zk1pIn
+q0pGsf8NAcSHMJgLqPQrTDw+zae4V03DvcYfNKjuno88d2226ld7MAmQZ7uRNsI
/CnkdelVs+akZsXf0szefSqMJlf08SY32t2jj4Ra7RApVYxOftD9nij/aLfuqOU6
ow6IgIcIG2ZvXLZwK87c5fxL7UAsTTV+M1sVv8jA33V2oKLhAgMBAAGjggG9MIIB
uTAfBgNVHSMEGDAWgBTay+qtWwhdzP/8JlTOSeVVxjj0+DAdBgNVHQ4EFgQUmP9l
6zhyrZ06Qj4zogt+6LKFk4AwDgYDVR0PAQH/BAQDAgWgMAwGA1UdEwEB/wQCMAAw
NAYDVR0lBC0wKwYIKwYBBQUHAwEGCCsGAQUFBwMCBgorBgEEAYI3CgMDBglghkgB
hvhCBAEwTwYDVR0gBEgwRjA6BgsrBgEEAbIxAQICBzArMCkGCCsGAQUFBwIBFh1o
dHRwczovL3NlY3VyZS5jb21vZG8uY29tL0NQUzAIBgZngQwBAgEwOwYDVR0fBDQw
MjAwoC6gLIYqaHR0cDovL2NybC5jb21vZG9jYS5jb20vRXNzZW50aWFsU1NMQ0Eu
Y3JsMG4GCCsGAQUFBwEBBGIwYDA4BggrBgEFBQcwAoYsaHR0cDovL2NydC5jb21v
ZG9jYS5jb20vRXNzZW50aWFsU1NMQ0FfMi5jcnQwJAYIKwYBBQUHMAGGGGh0dHA6
Ly9vY3NwLmNvbW9kb2NhLmNvbTAlBgNVHREEHjAcggptZWdhLmNvLm56gg53d3cu
bWVnYS5jby5uejANBgkqhkiG9w0BAQUFAAOCAQEAcYhrsPSvDuwihMOh0ZmRpbOE
Gw6LqKgLNTmaYUPQhzi2cyIjhUhNvugXQQlP5f0lp5j8cixmArafg1dTn4kQGgD3
ivtuhBTgKO1VYB/VRoAt6Lmswg3YqyiS7JiLDZxjoV7KoS5xdiaINfHDUaBBY4ZH
j2BUlPniNBjCqXe/HndUTVUewlxbVps9FyCmH+C4o9DWzdGBzDpCkcmo5nM+cp7q
ZhTIFTvZfo3zGuBoyu8BzuopCJcFRm3cRiXkpI7iOMUIixO1szkJS6WpL1sKdT73
UXp08U0LBqoqG130FbzEJBBV3ixbvY6BWMHoCWuaoF12KJnC5kHt2RoWAAgMXA==
-----END CERTIFICATE-----`
var comodoIntermediate1 = `-----BEGIN CERTIFICATE-----
MIIFAzCCA+ugAwIBAgIQGLLLuqME8aAPwfLzJkYqSjANBgkqhkiG9w0BAQUFADCB
gTELMAkGA1UEBhMCR0IxGzAZBgNVBAgTEkdyZWF0ZXIgTWFuY2hlc3RlcjEQMA4G
A1UEBxMHU2FsZm9yZDEaMBgGA1UEChMRQ09NT0RPIENBIExpbWl0ZWQxJzAlBgNV
BAMTHkNPTU9ETyBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTAeFw0wNjEyMDEwMDAw
MDBaFw0xOTEyMzEyMzU5NTlaMHIxCzAJBgNVBAYTAkdCMRswGQYDVQQIExJHcmVh
dGVyIE1hbmNoZXN0ZXIxEDAOBgNVBAcTB1NhbGZvcmQxGjAYBgNVBAoTEUNPTU9E
TyBDQSBMaW1pdGVkMRgwFgYDVQQDEw9Fc3NlbnRpYWxTU0wgQ0EwggEiMA0GCSqG
SIb3DQEBAQUAA4IBDwAwggEKAoIBAQCt8AiwcsargxIxF3CJhakgEtSYau2A1NHf
5I5ZLdOWIY120j8YC0YZYwvHIPPlC92AGvFaoL0dds23Izp0XmEbdaqb1IX04XiR
0y3hr/yYLgbSeT1awB8hLRyuIVPGOqchfr7tZ291HRqfalsGs2rjsQuqag7nbWzD
ypWMN84hHzWQfdvaGlyoiBSyD8gSIF/F03/o4Tjg27z5H6Gq1huQByH6RSRQXScq
oChBRVt9vKCiL6qbfltTxfEFFld+Edc7tNkBdtzffRDPUanlOPJ7FAB1WfnwWdsX
Pvev5gItpHnBXaIcw5rIp6gLSApqLn8tl2X2xQScRMiZln5+pN0vAgMBAAGjggGD
MIIBfzAfBgNVHSMEGDAWgBQLWOWLxkwVN6RAqTCpIb5HNlpW/zAdBgNVHQ4EFgQU
2svqrVsIXcz//CZUzknlVcY49PgwDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQI
MAYBAf8CAQAwIAYDVR0lBBkwFwYKKwYBBAGCNwoDAwYJYIZIAYb4QgQBMD4GA1Ud
IAQ3MDUwMwYEVR0gADArMCkGCCsGAQUFBwIBFh1odHRwczovL3NlY3VyZS5jb21v
ZG8uY29tL0NQUzBJBgNVHR8EQjBAMD6gPKA6hjhodHRwOi8vY3JsLmNvbW9kb2Nh
LmNvbS9DT01PRE9DZXJ0aWZpY2F0aW9uQXV0aG9yaXR5LmNybDBsBggrBgEFBQcB
AQRgMF4wNgYIKwYBBQUHMAKGKmh0dHA6Ly9jcnQuY29tb2RvY2EuY29tL0NvbW9k
b1VUTlNHQ0NBLmNydDAkBggrBgEFBQcwAYYYaHR0cDovL29jc3AuY29tb2RvY2Eu
Y29tMA0GCSqGSIb3DQEBBQUAA4IBAQAtlzR6QDLqcJcvgTtLeRJ3rvuq1xqo2l/z
odueTZbLN3qo6u6bldudu+Ennv1F7Q5Slqz0J790qpL0pcRDAB8OtXj5isWMcL2a
ejGjKdBZa0wztSz4iw+SY1dWrCRnilsvKcKxudokxeRiDn55w/65g+onO7wdQ7Vu
F6r7yJiIatnyfKH2cboZT7g440LX8NqxwCPf3dfxp+0Jj1agq8MLy6SSgIGSH6lv
+Wwz3D5XxqfyH8wqfOQsTEZf6/Nh9yvENZ+NWPU6g0QO2JOsTGvMd/QDzczc4BxL
XSXaPV7Od4rhPsbXlM1wSTz/Dr0ISKvlUhQVnQ6cGodWaK2cCQBk
-----END CERTIFICATE-----`
var comodoRoot = `-----BEGIN CERTIFICATE-----
MIIEHTCCAwWgAwIBAgIQToEtioJl4AsC7j41AkblPTANBgkqhkiG9w0BAQUFADCB
gTELMAkGA1UEBhMCR0IxGzAZBgNVBAgTEkdyZWF0ZXIgTWFuY2hlc3RlcjEQMA4G
A1UEBxMHU2FsZm9yZDEaMBgGA1UEChMRQ09NT0RPIENBIExpbWl0ZWQxJzAlBgNV
BAMTHkNPTU9ETyBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTAeFw0wNjEyMDEwMDAw
MDBaFw0yOTEyMzEyMzU5NTlaMIGBMQswCQYDVQQGEwJHQjEbMBkGA1UECBMSR3Jl
YXRlciBNYW5jaGVzdGVyMRAwDgYDVQQHEwdTYWxmb3JkMRowGAYDVQQKExFDT01P
RE8gQ0EgTGltaXRlZDEnMCUGA1UEAxMeQ09NT0RPIENlcnRpZmljYXRpb24gQXV0
aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0ECLi3LjkRv3
UcEbVASY06m/weaKXTuH+7uIzg3jLz8GlvCiKVCZrts7oVewdFFxze1CkU1B/qnI
2GqGd0S7WWaXUF601CxwRM/aN5VCaTwwxHGzUvAhTaHYujl8HJ6jJJ3ygxaYqhZ8
Q5sVW7euNJH+1GImGEaaP+vB+fGQV+useg2L23IwambV4EajcNxo2f8ESIl33rXp
+2dtQem8Ob0y2WIC8bGoPW43nOIv4tOiJovGuFVDiOEjPqXSJDlqR6sA1KGzqSX+
DT+nHbrTUcELpNqsOO9VUCQFZUaTNE8tja3G1CEZ0o7KBWFxB3NH5YoZEr0ETc5O
nKVIrLsm9wIDAQABo4GOMIGLMB0GA1UdDgQWBBQLWOWLxkwVN6RAqTCpIb5HNlpW
/zAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUwAwEB/zBJBgNVHR8EQjBAMD6g
PKA6hjhodHRwOi8vY3JsLmNvbW9kb2NhLmNvbS9DT01PRE9DZXJ0aWZpY2F0aW9u
QXV0aG9yaXR5LmNybDANBgkqhkiG9w0BAQUFAAOCAQEAPpiem/Yb6dc5t3iuHXIY
SdOH5EOC6z/JqvWote9VfCFSZfnVDeFs9D6Mk3ORLgLETgdxb8CPOGEIqB6BCsAv
IC9Bi5HcSEW88cbeunZrM8gALTFGTO3nnc+IlP8zwFboJIYmuNg4ON8qa90SzMc/
RxdMosIGlgnW2/4/PEZB31jiVg88O8EckzXZOFKs7sjsLjBOlDW0JB9LeGna8gI4
zJVSk/BwJVmcIGfE7vmLV2H0knZ9P4SNVbfo5azV8fUZVqZa+5Acr5Pr5RzUZ5dd
BA6+C4OmF4O5MBKgxTMVBbkN+8cFduPYSo38NBejxiEovjBFMR7HeL5YYTisO+IB
ZQ==
-----END CERTIFICATE-----`
......@@ -19,6 +19,8 @@ import (
"errors"
"io"
"math/big"
"net"
"strconv"
"time"
)
......@@ -360,16 +362,18 @@ const (
// id-kp-timeStamping OBJECT IDENTIFIER ::= { id-kp 8 }
// id-kp-OCSPSigning OBJECT IDENTIFIER ::= { id-kp 9 }
var (
oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0}
oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1}
oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2}
oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3}
oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4}
oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5}
oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6}
oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7}
oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8}
oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9}
oidExtKeyUsageAny = asn1.ObjectIdentifier{2, 5, 29, 37, 0}
oidExtKeyUsageServerAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1}
oidExtKeyUsageClientAuth = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2}
oidExtKeyUsageCodeSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3}
oidExtKeyUsageEmailProtection = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4}
oidExtKeyUsageIPSECEndSystem = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5}
oidExtKeyUsageIPSECTunnel = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6}
oidExtKeyUsageIPSECUser = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7}
oidExtKeyUsageTimeStamping = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8}
oidExtKeyUsageOCSPSigning = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9}
oidExtKeyUsageMicrosoftServerGatedCrypto = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3}
oidExtKeyUsageNetscapeServerGatedCrypto = asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1}
)
// ExtKeyUsage represents an extended set of actions that are valid for a given key.
......@@ -387,6 +391,8 @@ const (
ExtKeyUsageIPSECUser
ExtKeyUsageTimeStamping
ExtKeyUsageOCSPSigning
ExtKeyUsageMicrosoftServerGatedCrypto
ExtKeyUsageNetscapeServerGatedCrypto
)
// extKeyUsageOIDs contains the mapping between an ExtKeyUsage and its OID.
......@@ -404,6 +410,8 @@ var extKeyUsageOIDs = []struct {
{ExtKeyUsageIPSECUser, oidExtKeyUsageIPSECUser},
{ExtKeyUsageTimeStamping, oidExtKeyUsageTimeStamping},
{ExtKeyUsageOCSPSigning, oidExtKeyUsageOCSPSigning},
{ExtKeyUsageMicrosoftServerGatedCrypto, oidExtKeyUsageMicrosoftServerGatedCrypto},
{ExtKeyUsageNetscapeServerGatedCrypto, oidExtKeyUsageNetscapeServerGatedCrypto},
}
func extKeyUsageFromOID(oid asn1.ObjectIdentifier) (eku ExtKeyUsage, ok bool) {
......@@ -458,6 +466,7 @@ type Certificate struct {
// Subject Alternate Name values
DNSNames []string
EmailAddresses []string
IPAddresses []net.IP
// Name constraints
PermittedDNSDomainsCritical bool // if true then the name constraints are marked critical.
......@@ -660,6 +669,13 @@ func parsePublicKey(algo PublicKeyAlgorithm, keyData *publicKeyInfo) (interface{
return nil, err
}
if p.N.Sign() <= 0 {
return nil, errors.New("x509: RSA modulus is not a positive number")
}
if p.E <= 0 {
return nil, errors.New("x509: RSA public exponent is not a positive number")
}
pub := &rsa.PublicKey{
E: p.E,
N: p.N,
......@@ -713,7 +729,6 @@ func parsePublicKey(algo PublicKeyAlgorithm, keyData *publicKeyInfo) (interface{
default:
return nil, nil
}
panic("unreachable")
}
func parseCertificate(in *certificate) (*Certificate, error) {
......@@ -828,6 +843,13 @@ func parseCertificate(in *certificate) (*Certificate, error) {
case 2:
out.DNSNames = append(out.DNSNames, string(v.Bytes))
parsedName = true
case 7:
switch len(v.Bytes) {
case net.IPv4len, net.IPv6len:
out.IPAddresses = append(out.IPAddresses, v.Bytes)
default:
return nil, errors.New("x509: certificate contained IP address of length " + strconv.Itoa(len(v.Bytes)))
}
}
}
......@@ -1072,11 +1094,22 @@ func buildExtensions(template *Certificate) (ret []pkix.Extension, err error) {
n++
}
if len(template.DNSNames) > 0 {
if len(template.DNSNames) > 0 || len(template.EmailAddresses) > 0 || len(template.IPAddresses) > 0 {
ret[n].Id = oidExtensionSubjectAltName
rawValues := make([]asn1.RawValue, len(template.DNSNames))
for i, name := range template.DNSNames {
rawValues[i] = asn1.RawValue{Tag: 2, Class: 2, Bytes: []byte(name)}
var rawValues []asn1.RawValue
for _, name := range template.DNSNames {
rawValues = append(rawValues, asn1.RawValue{Tag: 2, Class: 2, Bytes: []byte(name)})
}
for _, email := range template.EmailAddresses {
rawValues = append(rawValues, asn1.RawValue{Tag: 1, Class: 2, Bytes: []byte(email)})
}
for _, rawIP := range template.IPAddresses {
// If possible, we always want to encode IPv4 addresses in 4 bytes.
ip := rawIP.To4()
if ip == nil {
ip = rawIP
}
rawValues = append(rawValues, asn1.RawValue{Tag: 7, Class: 2, Bytes: ip})
}
ret[n].Value, err = asn1.Marshal(rawValues)
if err != nil {
......
......@@ -19,6 +19,7 @@ import (
"encoding/hex"
"encoding/pem"
"math/big"
"net"
"reflect"
"testing"
"time"
......@@ -174,6 +175,49 @@ func TestMatchHostnames(t *testing.T) {
}
}
func TestMatchIP(t *testing.T) {
// Check that pattern matching is working.
c := &Certificate{
DNSNames: []string{"*.foo.bar.baz"},
Subject: pkix.Name{
CommonName: "*.foo.bar.baz",
},
}
err := c.VerifyHostname("quux.foo.bar.baz")
if err != nil {
t.Fatalf("VerifyHostname(quux.foo.bar.baz): %v", err)
}
// But check that if we change it to be matching against an IP address,
// it is rejected.
c = &Certificate{
DNSNames: []string{"*.2.3.4"},
Subject: pkix.Name{
CommonName: "*.2.3.4",
},
}
err = c.VerifyHostname("1.2.3.4")
if err == nil {
t.Fatalf("VerifyHostname(1.2.3.4) should have failed, did not")
}
c = &Certificate{
IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")},
}
err = c.VerifyHostname("127.0.0.1")
if err != nil {
t.Fatalf("VerifyHostname(127.0.0.1): %v", err)
}
err = c.VerifyHostname("::1")
if err != nil {
t.Fatalf("VerifyHostname(::1): %v", err)
}
err = c.VerifyHostname("[::1]")
if err != nil {
t.Fatalf("VerifyHostname([::1]): %v", err)
}
}
func TestCertificateParse(t *testing.T) {
s, _ := hex.DecodeString(certBytes)
certs, err := ParseCertificates(s)
......@@ -284,8 +328,11 @@ func TestCreateSelfSignedCertificate(t *testing.T) {
UnknownExtKeyUsage: testUnknownExtKeyUsage,
BasicConstraintsValid: true,
IsCA: true,
DNSNames: []string{"test.example.com"},
IsCA: true,
DNSNames: []string{"test.example.com"},
EmailAddresses: []string{"gopher@golang.org"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1).To4(), net.ParseIP("2001:4860:0:2001::68")},
PolicyIdentifiers: []asn1.ObjectIdentifier{[]int{1, 2, 3}},
PermittedDNSDomains: []string{".example.com", "example.com"},
......@@ -327,6 +374,18 @@ func TestCreateSelfSignedCertificate(t *testing.T) {
t.Errorf("%s: unknown extkeyusage wasn't correctly copied from the template. Got %v, want %v", test.name, cert.UnknownExtKeyUsage, testUnknownExtKeyUsage)
}
if !reflect.DeepEqual(cert.DNSNames, template.DNSNames) {
t.Errorf("%s: SAN DNS names differ from template. Got %v, want %v", test.name, cert.DNSNames, template.DNSNames)
}
if !reflect.DeepEqual(cert.EmailAddresses, template.EmailAddresses) {
t.Errorf("%s: SAN emails differ from template. Got %v, want %v", test.name, cert.EmailAddresses, template.EmailAddresses)
}
if !reflect.DeepEqual(cert.IPAddresses, template.IPAddresses) {
t.Errorf("%s: SAN IPs differ from template. Got %v, want %v", test.name, cert.IPAddresses, template.IPAddresses)
}
if test.checkSig {
err = cert.CheckSignatureFrom(cert)
if err != nil {
......
......@@ -14,12 +14,18 @@ import (
"strconv"
)
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
// driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
// The statement si may be nil, if no statement is available.
func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
// The statement ds may be nil, if no statement is available.
func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
dargs := make([]driver.Value, len(args))
var si driver.Stmt
if ds != nil {
si = ds.si
}
cc, ok := si.(driver.ColumnConverter)
// Normal path, for a driver.Stmt that is not a ColumnConverter.
......@@ -58,7 +64,9 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
// column before going across the network to get the
// same error.
var err error
ds.Lock()
dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
ds.Unlock()
if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
}
......@@ -75,34 +83,68 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type.
func convertAssign(dest, src interface{}) error {
// Common cases, without reflect. Fall through.
// Common cases, without reflect.
switch s := src.(type) {
case string:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = s
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = []byte(s)
return nil
}
case []byte:
switch d := dest.(type) {
case *string:
if d == nil {
return errNilPtr
}
*d = string(s)
return nil
case *interface{}:
bcopy := make([]byte, len(s))
copy(bcopy, s)
*d = bcopy
if d == nil {
return errNilPtr
}
*d = cloneBytes(s)
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = cloneBytes(s)
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = s
return nil
}
case nil:
switch d := dest.(type) {
case *interface{}:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *[]byte:
if d == nil {
return errNilPtr
}
*d = nil
return nil
case *RawBytes:
if d == nil {
return errNilPtr
}
*d = nil
return nil
}
......@@ -121,6 +163,26 @@ func convertAssign(dest, src interface{}) error {
*d = fmt.Sprintf("%v", src)
return nil
}
case *[]byte:
sv = reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
*d = []byte(fmt.Sprintf("%v", src))
return nil
}
case *RawBytes:
sv = reflect.ValueOf(src)
switch sv.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64:
*d = RawBytes(fmt.Sprintf("%v", src))
return nil
}
case *bool:
bv, err := driver.Bool.ConvertValue(src)
if err == nil {
......@@ -140,6 +202,9 @@ func convertAssign(dest, src interface{}) error {
if dpv.Kind() != reflect.Ptr {
return errors.New("destination not a pointer")
}
if dpv.IsNil() {
return errNilPtr
}
if !sv.IsValid() {
sv = reflect.ValueOf(src)
......@@ -189,6 +254,16 @@ func convertAssign(dest, src interface{}) error {
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
}
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
} else {
c := make([]byte, len(b))
copy(c, b)
return c
}
}
func asString(src interface{}) string {
switch v := src.(type) {
case string:
......
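Seen from the database/sql caller, the new conversions let Scan target []byte and RawBytes directly; a hedged sketch (package name, query, and schema are placeholders, and a driver is assumed to be registered):

package peopledb

import (
	"database/sql"
	"fmt"
)

// printNames scans each row's name column into a RawBytes, which aliases the
// driver's buffer and is only valid until the next Next, Scan, or Close.
func printNames(db *sql.DB) error {
	rows, err := db.Query("SELECT name FROM people") // placeholder query
	if err != nil {
		return err
	}
	defer rows.Close()
	var raw sql.RawBytes
	for rows.Next() {
		if err := rows.Scan(&raw); err != nil {
			return err
		}
		fmt.Printf("%s\n", raw)
	}
	return rows.Err()
}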
......@@ -22,6 +22,8 @@ type conversionTest struct {
wantint int64
wantuint uint64
wantstr string
wantbytes []byte
wantraw RawBytes
wantf32 float32
wantf64 float64
wanttime time.Time
......@@ -35,6 +37,8 @@ type conversionTest struct {
// Target variables for scanning into.
var (
scanstr string
scanbytes []byte
scanraw RawBytes
scanint int
scanint8 int8
scanint16 int16
......@@ -56,6 +60,7 @@ var conversionTests = []conversionTest{
{s: someTime, d: &scantime, wanttime: someTime},
// To strings
{s: "string", d: &scanstr, wantstr: "string"},
{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
{s: 123, d: &scanstr, wantstr: "123"},
{s: int8(123), d: &scanstr, wantstr: "123"},
......@@ -66,6 +71,31 @@ var conversionTests = []conversionTest{
{s: uint64(123), d: &scanstr, wantstr: "123"},
{s: 1.5, d: &scanstr, wantstr: "1.5"},
// To []byte
{s: nil, d: &scanbytes, wantbytes: nil},
{s: "string", d: &scanbytes, wantbytes: []byte("string")},
{s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
{s: 123, d: &scanbytes, wantbytes: []byte("123")},
{s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
{s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
{s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
{s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
{s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
{s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
{s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},
// To RawBytes
{s: nil, d: &scanraw, wantraw: nil},
{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
{s: 123, d: &scanraw, wantraw: RawBytes("123")},
{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
{s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
{s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
// Strings to integers
{s: "255", d: &scanuint8, wantuint: 255},
{s: "256", d: &scanuint8, wanterr: `converting string "256" to a uint8: strconv.ParseUint: parsing "256": value out of range`},
......@@ -113,6 +143,7 @@ var conversionTests = []conversionTest{
{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
{s: true, d: &scaniface, wantiface: true},
{s: nil, d: &scaniface},
{s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},
}
func intPtrValue(intptr interface{}) interface{} {
......@@ -191,7 +222,7 @@ func TestConversions(t *testing.T) {
}
if srcBytes, ok := ct.s.([]byte); ok {
dstBytes := (*ifptr).([]byte)
if &dstBytes[0] == &srcBytes[0] {
if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
errf("copy into interface{} didn't copy []byte data")
}
}
......
......@@ -10,8 +10,8 @@ package driver
import "errors"
// A driver Value is a value that drivers must be able to handle.
// A Value is either nil or an instance of one of these types:
// Value is a value that drivers must be able to handle.
// It is either nil or an instance of one of these types:
//
// int64
// float64
......@@ -56,7 +56,7 @@ var ErrBadConn = errors.New("driver: bad connection")
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the db package's DB.Exec will
// If a Conn does not implement Execer, the sql package's DB.Exec will
// first prepare a query, execute the statement, and then close the
// statement.
//
......@@ -65,6 +65,17 @@ type Execer interface {
Exec(query string, args []Value) (Result, error)
}
// Queryer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Queryer, the sql package's DB.Query will
// first prepare a query, execute the statement, and then close the
// statement.
//
// Query may return ErrSkip.
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
// Conn is a connection to a database. It is not used concurrently
// by multiple goroutines.
//
......@@ -104,23 +115,8 @@ type Result interface {
type Stmt interface {
// Close closes the statement.
//
// Closing a statement should not interrupt any outstanding
// query created from that statement. That is, the following
// order of operations is valid:
//
// * create a driver statement
// * call Query on statement, returning Rows
// * close the statement
// * read from Rows
//
// If closing a statement invalidates currently-running
// queries, the final step above will incorrectly fail.
//
// TODO(bradfitz): possibly remove the restriction above, if
// enough driver authors object and find it complicates their
// code too much. The sql package could be smarter about
// refcounting the statement and closing it at the appropriate
// time.
// As of Go 1.1, a Stmt will not be closed if it's in use
// by any queries.
Close() error
// NumInput returns the number of placeholder parameters.
......
......@@ -13,6 +13,7 @@ import (
"strconv"
"strings"
"sync"
"testing"
"time"
)
......@@ -34,9 +35,10 @@ var _ = log.Printf
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
mu sync.Mutex
openCount int
dbs map[string]*fakeDB
mu sync.Mutex // guards 3 following fields
openCount int // conn opens
closeCount int // conn closes
dbs map[string]*fakeDB
}
type fakeDB struct {
......@@ -229,7 +231,43 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
return c.currTx, nil
}
func (c *fakeConn) Close() error {
var hookPostCloseConn struct {
sync.Mutex
fn func(*fakeConn, error)
}
func setHookpostCloseConn(fn func(*fakeConn, error)) {
hookPostCloseConn.Lock()
defer hookPostCloseConn.Unlock()
hookPostCloseConn.fn = fn
}
var testStrictClose *testing.T
// setStrictFakeConnClose sets the t on which to call Errorf when fakeConn.Close
// fails to close. If t is nil, the check is disabled.
func setStrictFakeConnClose(t *testing.T) {
testStrictClose = t
}
func (c *fakeConn) Close() (err error) {
drv := fdriver.(*fakeDriver)
defer func() {
if err != nil && testStrictClose != nil {
testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
}
hookPostCloseConn.Lock()
fn := hookPostCloseConn.fn
hookPostCloseConn.Unlock()
if fn != nil {
fn(c, err)
}
if err == nil {
drv.mu.Lock()
drv.closeCount++
drv.mu.Unlock()
}
}()
if c.currTx != nil {
return errors.New("can't close fakeConn; in a Transaction")
}
......@@ -266,6 +304,18 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error
return nil, driver.ErrSkip
}
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
// implement this at all.
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
return nil, driver.ErrSkip
}
func errf(msg string, args ...interface{}) error {
return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
}
......@@ -412,6 +462,12 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
}
func (s *fakeStmt) Close() error {
if s.c == nil {
panic("nil conn in fakeStmt.Close")
}
if s.c.db == nil {
panic("in fakeStmt.Close, conn's db is nil (already closed)")
}
if !s.closed {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
......@@ -503,6 +559,15 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
if !ok {
return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
if s.table == "magicquery" {
if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
if args[0] == "sleep" {
time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
}
}
}
t.mu.Lock()
defer t.mu.Unlock()
......
......@@ -13,17 +13,45 @@ import (
// Data buffer being decoded.
type buf struct {
dwarf *Data
u *unit
order binary.ByteOrder
name string
off Offset
data []byte
err error
dwarf *Data
order binary.ByteOrder
format dataFormat
name string
off Offset
data []byte
err error
}
func makeBuf(d *Data, u *unit, name string, off Offset, data []byte) buf {
return buf{d, u, d.order, name, off, data, nil}
// Data format, other than byte order. This affects the handling of
// certain field formats.
type dataFormat interface {
// DWARF version number. Zero means unknown.
version() int
// 64-bit DWARF format?
dwarf64() (dwarf64 bool, isKnown bool)
// Size of an address, in bytes. Zero means unknown.
addrsize() int
}
// Some parts of DWARF have no data format, e.g., abbrevs.
type unknownFormat struct{}
func (u unknownFormat) version() int {
return 0
}
func (u unknownFormat) dwarf64() (bool, bool) {
return false, false
}
func (u unknownFormat) addrsize() int {
return 0
}
func makeBuf(d *Data, format dataFormat, name string, off Offset, data []byte) buf {
return buf{d, d.order, format, name, off, data, nil}
}
func (b *buf) uint8() uint8 {
......@@ -121,17 +149,15 @@ func (b *buf) int() int64 {
// Address-sized uint.
func (b *buf) addr() uint64 {
if b.u != nil {
switch b.u.addrsize {
case 1:
return uint64(b.uint8())
case 2:
return uint64(b.uint16())
case 4:
return uint64(b.uint32())
case 8:
return uint64(b.uint64())
}
switch b.format.addrsize() {
case 1:
return uint64(b.uint8())
case 2:
return uint64(b.uint16())
case 4:
return uint64(b.uint32())
case 8:
return uint64(b.uint64())
}
b.error("unknown address size")
return 0
......
......@@ -40,7 +40,7 @@ func (d *Data) parseAbbrev(off uint32) (abbrevTable, error) {
} else {
data = data[off:]
}
b := makeBuf(d, nil, "abbrev", 0, data)
b := makeBuf(d, unknownFormat{}, "abbrev", 0, data)
// Error handling is simplified by the buf getters
// returning an endless stream of 0s after an error.
......@@ -190,13 +190,16 @@ func (b *buf) entry(atab abbrevTable, ubase Offset) *Entry {
case formFlag:
val = b.uint8() == 1
case formFlagPresent:
// The attribute is implicitly indicated as present, and no value is
// encoded in the debugging information entry itself.
val = true
// lineptr, loclistptr, macptr, rangelistptr
case formSecOffset:
if b.u == nil {
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_sec_offset")
} else if b.u.dwarf64 {
} else if is64 {
val = Offset(b.uint64())
} else {
val = Offset(b.uint32())
......@@ -204,14 +207,20 @@ func (b *buf) entry(atab abbrevTable, ubase Offset) *Entry {
// reference to other entry
case formRefAddr:
if b.u == nil {
vers := b.format.version()
if vers == 0 {
b.error("unknown version for DW_FORM_ref_addr")
} else if b.u.version == 2 {
} else if vers == 2 {
val = Offset(b.addr())
} else if b.u.dwarf64 {
val = Offset(b.uint64())
} else {
val = Offset(b.uint32())
is64, known := b.format.dwarf64()
if !known {
b.error("unknown size for DW_FORM_ref_addr")
} else if is64 {
val = Offset(b.uint64())
} else {
val = Offset(b.uint32())
}
}
case formRef1:
val = Offset(b.uint8()) + ubase
......@@ -234,7 +243,7 @@ func (b *buf) entry(atab abbrevTable, ubase Offset) *Entry {
if b.err != nil {
return nil
}
b1 := makeBuf(b.dwarf, b.u, "str", 0, b.dwarf.str)
b1 := makeBuf(b.dwarf, unknownFormat{}, "str", 0, b.dwarf.str)
b1.skip(int(off))
val = b1.string()
if b1.err != nil {
......
......@@ -112,7 +112,7 @@ func (d *Data) readUnitLine(i int, u *unit) error {
func (d *Data) readAddressRanges(off Offset, base uint64, u *unit) error {
b := makeBuf(d, u, "ranges", off, d.ranges[off:])
var highest uint64
switch u.addrsize {
switch u.addrsize() {
case 1:
highest = 0xff
case 2:
......
......@@ -435,7 +435,9 @@ func (d *Data) Type(off Offset) (Type, error) {
goto Error
}
if loc, ok := kid.Val(AttrDataMemberLoc).([]byte); ok {
b := makeBuf(d, nil, "location", 0, loc)
// TODO: Should have original compilation
// unit here, not unknownFormat.
b := makeBuf(d, unknownFormat{}, "location", 0, loc)
if b.uint8() != opPlusUconst {
err = DecodeError{"info", kid.Offset, "unexpected opcode"}
goto Error
......
......@@ -10,17 +10,31 @@ import "strconv"
// Each unit has its own abbreviation table and address size.
type unit struct {
base Offset // byte offset of header within the aggregate info
off Offset // byte offset of data within the aggregate info
lineoff Offset // byte offset of data within the line info
data []byte
atable abbrevTable
addrsize int
version int
dwarf64 bool // True for 64-bit DWARF format
dir string
pc []addrRange // PC ranges in this compilation unit
lines []mapLineInfo // PC -> line mapping
base Offset // byte offset of header within the aggregate info
off Offset // byte offset of data within the aggregate info
lineoff Offset // byte offset of data within the line info
data []byte
atable abbrevTable
asize int
vers int
is64 bool // True for 64-bit DWARF format
dir string
pc []addrRange // PC ranges in this compilation unit
lines []mapLineInfo // PC -> line mapping
}
// Implement the dataFormat interface.
func (u *unit) version() int {
return u.vers
}
func (u *unit) dwarf64() (bool, bool) {
return u.is64, true
}
func (u *unit) addrsize() int {
return u.asize
}
// A range is an address range.
......@@ -32,12 +46,12 @@ type addrRange struct {
func (d *Data) parseUnits() ([]unit, error) {
// Count units.
nunit := 0
b := makeBuf(d, nil, "info", 0, d.info)
b := makeBuf(d, unknownFormat{}, "info", 0, d.info)
for len(b.data) > 0 {
len := b.uint32()
if len == 0xffffffff {
len64 := b.uint64()
if len64 != uint64(int(len64)) {
if len64 != uint64(uint32(len64)) {
b.error("unit length overflow")
break
}
......@@ -51,14 +65,14 @@ func (d *Data) parseUnits() ([]unit, error) {
}
// Again, this time writing them down.
b = makeBuf(d, nil, "info", 0, d.info)
b = makeBuf(d, unknownFormat{}, "info", 0, d.info)
units := make([]unit, nunit)
for i := range units {
u := &units[i]
u.base = b.off
n := b.uint32()
if n == 0xffffffff {
u.dwarf64 = true
u.is64 = true
n = uint32(b.uint64())
}
vers := b.uint16()
......@@ -66,6 +80,7 @@ func (d *Data) parseUnits() ([]unit, error) {
b.error("unsupported DWARF version " + strconv.Itoa(int(vers)))
break
}
u.vers = int(vers)
atable, err := d.parseAbbrev(b.uint32())
if err != nil {
if b.err == nil {
......@@ -73,9 +88,8 @@ func (d *Data) parseUnits() ([]unit, error) {
}
break
}
u.version = int(vers)
u.atable = atable
u.addrsize = int(b.uint8())
u.asize = int(b.uint8())
u.off = b.off
u.data = b.bytes(int(n - (2 + 4 + 1)))
}
......
......@@ -422,6 +422,10 @@ func (f *File) getSymbols32(typ SectionType) ([]Symbol, []byte, error) {
return nil, nil, errors.New("cannot load string table section")
}
// The first entry is all zeros.
var skip [Sym32Size]byte
symtab.Read(skip[:])
symbols := make([]Symbol, symtab.Len()/Sym32Size)
i := 0
......@@ -461,6 +465,10 @@ func (f *File) getSymbols64(typ SectionType) ([]Symbol, []byte, error) {
return nil, nil, errors.New("cannot load string table section")
}
// The first entry is all zeros.
var skip [Sym64Size]byte
symtab.Read(skip[:])
symbols := make([]Symbol, symtab.Len()/Sym64Size)
i := 0
......@@ -533,10 +541,10 @@ func (f *File) applyRelocationsAMD64(dst []byte, rels []byte) error {
symNo := rela.Info >> 32
t := R_X86_64(rela.Info & 0xffff)
if symNo >= uint64(len(symbols)) {
if symNo == 0 || symNo > uint64(len(symbols)) {
continue
}
sym := &symbols[symNo]
sym := &symbols[symNo-1]
if SymType(sym.Info&0xf) != STT_SECTION {
// We don't handle non-section relocations for now.
continue
......@@ -597,6 +605,10 @@ func (f *File) DWARF() (*dwarf.Data, error) {
}
// Symbols returns the symbol table for f.
//
// For compatibility with Go 1.0, Symbols omits the null symbol at index 0.
// After retrieving the symbols as symtab, an externally supplied index x
// corresponds to symtab[x-1], not symtab[x].
func (f *File) Symbols() ([]Symbol, error) {
sym, _, err := f.getSymbols(SHT_SYMTAB)
return sym, err
......@@ -706,7 +718,7 @@ func (f *File) gnuVersionInit(str []byte) {
// which came from offset i of the symbol table.
func (f *File) gnuVersion(i int, sym *ImportedSymbol) {
// Each entry is two bytes.
i = i * 2
i = (i + 1) * 2
if i >= len(f.gnuVersym) {
return
}
......
......@@ -99,31 +99,116 @@ type Table struct {
}
type sym struct {
value uint32
gotype uint32
value uint64
gotype uint64
typ byte
name []byte
}
var littleEndianSymtab = []byte{0xFE, 0xFF, 0xFF, 0xFF, 0x00, 0x00}
var littleEndianSymtab = []byte{0xFD, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00}
var bigEndianSymtab = []byte{0xFF, 0xFF, 0xFF, 0xFD, 0x00, 0x00, 0x00}
var oldLittleEndianSymtab = []byte{0xFE, 0xFF, 0xFF, 0xFF, 0x00, 0x00}
func walksymtab(data []byte, fn func(sym) error) error {
var order binary.ByteOrder = binary.BigEndian
if bytes.HasPrefix(data, littleEndianSymtab) {
newTable := false
switch {
case bytes.HasPrefix(data, oldLittleEndianSymtab):
// Same as Go 1.0, but little endian.
// Format was used during interim development between Go 1.0 and Go 1.1.
// Should not be widespread, but easy to support.
data = data[6:]
order = binary.LittleEndian
case bytes.HasPrefix(data, bigEndianSymtab):
newTable = true
case bytes.HasPrefix(data, littleEndianSymtab):
newTable = true
order = binary.LittleEndian
}
var ptrsz int
if newTable {
if len(data) < 8 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
ptrsz = int(data[7])
if ptrsz != 4 && ptrsz != 8 {
return &DecodingError{7, "invalid pointer size", ptrsz}
}
data = data[8:]
}
var s sym
p := data
for len(p) >= 6 {
s.value = order.Uint32(p[0:4])
typ := p[4]
if typ&0x80 == 0 {
return &DecodingError{len(data) - len(p) + 4, "bad symbol type", typ}
for len(p) >= 4 {
var typ byte
if newTable {
// Symbol type, value, Go type.
typ = p[0] & 0x3F
wideValue := p[0]&0x40 != 0
goType := p[0]&0x80 != 0
if typ < 26 {
typ += 'A'
} else {
typ += 'a' - 26
}
s.typ = typ
p = p[1:]
if wideValue {
if len(p) < ptrsz {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// fixed-width value
if ptrsz == 8 {
s.value = order.Uint64(p[0:8])
p = p[8:]
} else {
s.value = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
} else {
// varint value
s.value = 0
shift := uint(0)
for len(p) > 0 && p[0]&0x80 != 0 {
s.value |= uint64(p[0]&0x7F) << shift
shift += 7
p = p[1:]
}
if len(p) == 0 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
s.value |= uint64(p[0]) << shift
p = p[1:]
}
if goType {
if len(p) < ptrsz {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// fixed-width go type
if ptrsz == 8 {
s.gotype = order.Uint64(p[0:8])
p = p[8:]
} else {
s.gotype = uint64(order.Uint32(p[0:4]))
p = p[4:]
}
}
} else {
// Value, symbol type.
s.value = uint64(order.Uint32(p[0:4]))
if len(p) < 5 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
typ = p[4]
if typ&0x80 == 0 {
return &DecodingError{len(data) - len(p) + 4, "bad symbol type", typ}
}
typ &^= 0x80
s.typ = typ
p = p[5:]
}
typ &^= 0x80
s.typ = typ
p = p[5:]
// Name.
var i int
var nnul int
for i = 0; i < len(p); i++ {
......@@ -142,13 +227,21 @@ func walksymtab(data []byte, fn func(sym) error) error {
}
}
}
if i+nnul+4 > len(p) {
if len(p) < i+nnul {
return &DecodingError{len(data), "unexpected EOF", nil}
}
s.name = p[0:i]
i += nnul
s.gotype = order.Uint32(p[i : i+4])
p = p[i+4:]
p = p[i:]
if !newTable {
if len(p) < 4 {
return &DecodingError{len(data), "unexpected EOF", nil}
}
// Go type.
s.gotype = uint64(order.Uint32(p[:4]))
p = p[4:]
}
fn(s)
}
return nil
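The value encoding handled above is a little-endian base-128 varint: bytes with the high bit set contribute their low seven bits, and the first byte with the high bit clear ends the number (the new table also starts with a four-byte magic, three zero bytes and a pointer-size byte, as checked at the top of walksymtab). A standalone sketch of the same decoding with a round-trip example; the helper name is made up:
package main

import "fmt"

// decodeVarint mirrors the varint loop in walksymtab: seven bits per byte,
// least-significant group first, high bit set on continuation bytes.
func decodeVarint(p []byte) (v uint64, rest []byte) {
	shift := uint(0)
	for len(p) > 0 && p[0]&0x80 != 0 {
		v |= uint64(p[0]&0x7F) << shift
		shift += 7
		p = p[1:]
	}
	if len(p) > 0 {
		v |= uint64(p[0]) << shift
		p = p[1:]
	}
	return v, p
}

func main() {
	// 300 = 0b100101100 encodes as 0xAC (low seven bits, continuation set), then 0x02.
	v, _ := decodeVarint([]byte{0xAC, 0x02})
	fmt.Println(v) // 300
}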
......
......@@ -142,6 +142,8 @@ type Dysymtab struct {
* Mach-O reader
*/
// FormatError is returned by some operations if the data does
// not have the correct format for an object file.
type FormatError struct {
off int64
msg string
......
......@@ -296,5 +296,4 @@ func (d *decoder) Read(p []byte) (n int, err error) {
nn, d.readErr = d.r.Read(d.buf[d.nbuf:])
d.nbuf += nn
}
panic("unreachable")
}
......@@ -460,7 +460,6 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter
default:
return marshalUTF8String(out, v.String())
}
return
}
return StructuralError{"unknown Go type"}
......
......@@ -6,8 +6,10 @@
package base32
import (
"bytes"
"io"
"strconv"
"strings"
)
/*
......@@ -48,6 +50,13 @@ var StdEncoding = NewEncoding(encodeStd)
// It is typically used in DNS.
var HexEncoding = NewEncoding(encodeHex)
var removeNewlinesMapper = func(r rune) rune {
if r == '\r' || r == '\n' {
return -1
}
return r
}
/*
* Encoder
*/
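strings.Map and bytes.Map drop a character whenever the mapper returns a negative rune, which is how removeNewlinesMapper strips CR and LF before decoding (see Decode and DecodeString below). A minimal sketch; the input literal is just an example:
package main

import (
	"fmt"
	"strings"
)

func main() {
	removeNewlines := func(r rune) rune {
		if r == '\r' || r == '\n' {
			return -1 // a negative return tells Map to drop the rune
		}
		return r
	}
	fmt.Println(strings.Map(removeNewlines, "ON2X\r\nEZI=")) // prints ON2XEZI=
}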
......@@ -228,40 +237,47 @@ func (e CorruptInputError) Error() string {
// decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding was encountered and thus any
// additional data is an error.
// additional data is an error. This method assumes that src has been
// stripped of all supported whitespace ('\r' and '\n').
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
osrc := src
olen := len(src)
for len(src) > 0 && !end {
// Decode quantum using the base32 alphabet
var dbuf [8]byte
dlen := 8
// do the top bytes contain any data?
for j := 0; j < 8; {
if len(src) == 0 {
return n, false, CorruptInputError(len(osrc) - len(src) - j)
return n, false, CorruptInputError(olen - len(src) - j)
}
in := src[0]
src = src[1:]
if in == '\r' || in == '\n' {
// Ignore this character.
continue
}
if in == '=' && j >= 2 && len(src) < 8 {
// We've reached the end and there's
// padding, the rest should be padded
for k := 0; k < 8-j-1; k++ {
// We've reached the end and there's padding
if len(src)+j < 8-1 {
// not enough padding
return n, false, CorruptInputError(olen)
}
for k := 0; k < 8-1-j; k++ {
if len(src) > k && src[k] != '=' {
return n, false, CorruptInputError(len(osrc) - len(src) + k - 1)
// incorrect padding
return n, false, CorruptInputError(olen - len(src) + k - 1)
}
}
dlen = j
end = true
dlen, end = j, true
// 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not
// valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing
// the five valid padding lengths, and Section 9 "Illustrations and
// Examples" for an illustration for how the the 1st, 3rd and 6th base32
// src bytes do not yield enough information to decode a dst byte.
if dlen == 1 || dlen == 3 || dlen == 6 {
return n, false, CorruptInputError(olen - len(src) - 1)
}
break
}
dbuf[j] = enc.decodeMap[in]
if dbuf[j] == 0xFF {
return n, false, CorruptInputError(len(osrc) - len(src) - 1)
return n, false, CorruptInputError(olen - len(src) - 1)
}
j++
}
......@@ -269,16 +285,16 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// Pack 8x 5-bit source blocks into 5 byte destination
// quantum
switch dlen {
case 7, 8:
case 8:
dst[4] = dbuf[6]<<5 | dbuf[7]
fallthrough
case 6, 5:
case 7:
dst[3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3
fallthrough
case 4:
case 5:
dst[2] = dbuf[3]<<4 | dbuf[4]>>1
fallthrough
case 3:
case 4:
dst[1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4
fallthrough
case 2:
......@@ -288,11 +304,11 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
switch dlen {
case 2:
n += 1
case 3, 4:
case 4:
n += 2
case 5:
n += 3
case 6, 7:
case 7:
n += 4
case 8:
n += 5
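The dlen-to-n mapping follows from each base32 character carrying five bits: n decoded bytes need ceil(8n/5) characters per quantum, which can only be 2, 4, 5, 7 or 8, matching the switch above and explaining why dlen values of 1, 3 and 6 are rejected earlier. A quick arithmetic sketch:
package main

import "fmt"

func main() {
	for n := 1; n <= 5; n++ {
		chars := (8*n + 4) / 5 // ceil(8n/5) base32 characters for n bytes
		fmt.Printf("%d byte(s) <-> %d char(s)\n", n, chars)
	}
	// Prints 1<->2, 2<->4, 3<->5, 4<->7, 5<->8.
}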
......@@ -307,12 +323,14 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// number of bytes successfully written and CorruptInputError.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
src = bytes.Map(removeNewlinesMapper, src)
n, _, err = enc.decode(dst, src)
return
}
// DecodeString returns the bytes represented by the base32 string s.
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
s = strings.Map(removeNewlinesMapper, s)
dbuf := make([]byte, enc.DecodedLen(len(s)))
n, err := enc.Decode(dbuf, []byte(s))
return dbuf[:n], err
......@@ -377,9 +395,34 @@ func (d *decoder) Read(p []byte) (n int, err error) {
return n, d.err
}
type newlineFilteringReader struct {
wrapped io.Reader
}
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
n, err := r.wrapped.Read(p)
for n > 0 {
offset := 0
for i, b := range p[0:n] {
if b != '\r' && b != '\n' {
if i != offset {
p[offset] = b
}
offset++
}
}
if offset > 0 {
return offset, err
}
// Previous buffer entirely whitespace, read again
n, err = r.wrapped.Read(p)
}
return n, err
}
// NewDecoder constructs a new base32 stream decoder.
func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
return &decoder{enc: enc, r: r}
return &decoder{enc: enc, r: &newlineFilteringReader{r}}
}
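With the filtering reader in place, a streaming decode tolerates line breaks inside the encoded data, which is what TestDecoderIssue4779 below exercises. A small usage sketch; the input literal is just an example and error handling is minimal:
package main

import (
	"encoding/base32"
	"fmt"
	"io/ioutil"
	"strings"
)

func main() {
	in := "ON2X\r\nEZI=" // "sure" with an embedded CR/LF
	dec := base32.NewDecoder(base32.StdEncoding, strings.NewReader(in))
	out, err := ioutil.ReadAll(dec)
	fmt.Printf("%q %v\n", out, err) // "sure" <nil>
}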
// DecodedLen returns the maximum length in bytes of the decoded data
......
......@@ -8,6 +8,7 @@ import (
"bytes"
"io"
"io/ioutil"
"strings"
"testing"
)
......@@ -137,27 +138,48 @@ func TestDecoderBuffering(t *testing.T) {
}
func TestDecodeCorrupt(t *testing.T) {
type corrupt struct {
e string
p int
}
examples := []corrupt{
testCases := []struct {
input string
offset int // -1 means no corruption.
}{
{"", -1},
{"!!!!", 0},
{"x===", 0},
{"AA=A====", 2},
{"AAA=AAAA", 3},
{"MMMMMMMMM", 8},
{"MMMMMM", 0},
{"A=", 1},
{"AA=", 3},
{"AA==", 4},
{"AA===", 5},
{"AAAA=", 5},
{"AAAA==", 6},
{"AAAAA=", 6},
{"AAAAA==", 7},
{"A=======", 1},
{"AA======", -1},
{"AAA=====", 3},
{"AAAA====", -1},
{"AAAAA===", -1},
{"AAAAAA==", 6},
{"AAAAAAA=", -1},
{"AAAAAAAA", -1},
}
for _, e := range examples {
dbuf := make([]byte, StdEncoding.DecodedLen(len(e.e)))
_, err := StdEncoding.Decode(dbuf, []byte(e.e))
for _, tc := range testCases {
dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input)))
_, err := StdEncoding.Decode(dbuf, []byte(tc.input))
if tc.offset == -1 {
if err != nil {
t.Error("Decoder wrongly detected coruption in", tc.input)
}
continue
}
switch err := err.(type) {
case CorruptInputError:
testEqual(t, "Corruption in %q at offset %v, want %v", e.e, int(err), e.p)
testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset)
default:
t.Error("Decoder failed to detect corruption in", e)
t.Error("Decoder failed to detect corruption in", tc)
}
}
}
......@@ -195,9 +217,21 @@ func TestBig(t *testing.T) {
}
}
func testStringEncoding(t *testing.T, expected string, examples []string) {
for _, e := range examples {
buf, err := StdEncoding.DecodeString(e)
if err != nil {
t.Errorf("Decode(%q) failed: %v", e, err)
continue
}
if s := string(buf); s != expected {
t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
}
}
}
func TestNewLineCharacters(t *testing.T) {
// Each of these should decode to the string "sure", without errors.
const expected = "sure"
examples := []string{
"ON2XEZI=",
"ON2XEZI=\r",
......@@ -209,14 +243,44 @@ func TestNewLineCharacters(t *testing.T) {
"ON2XEZ\nI=",
"ON2XEZI\n=",
}
for _, e := range examples {
buf, err := StdEncoding.DecodeString(e)
if err != nil {
t.Errorf("Decode(%q) failed: %v", e, err)
continue
}
if s := string(buf); s != expected {
t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
}
testStringEncoding(t, "sure", examples)
// Each of these should decode to the string "foobar", without errors.
examples = []string{
"MZXW6YTBOI======",
"MZXW6YTBOI=\r\n=====",
}
testStringEncoding(t, "foobar", examples)
}
func TestDecoderIssue4779(t *testing.T) {
encoded := `JRXXEZLNEBUXA43VNUQGI33MN5ZCA43JOQQGC3LFOQWCAY3PNZZWKY3UMV2HK4
RAMFSGS4DJONUWG2LOM4QGK3DJOQWCA43FMQQGI3YKMVUXK43NN5SCA5DFNVYG64RANFXGG2LENFSH
K3TUEB2XIIDMMFRG64TFEBSXIIDEN5WG64TFEBWWCZ3OMEQGC3DJOF2WCLRAKV2CAZLONFWQUYLEEB
WWS3TJNUQHMZLONFQW2LBAOF2WS4ZANZXXG5DSOVSCAZLYMVZGG2LUMF2GS33OEB2WY3DBNVRW6IDM
MFRG64TJOMQG42LTNEQHK5AKMFWGS4LVNFYCAZLYEBSWCIDDN5WW233EN4QGG33OONSXC5LBOQXCAR
DVNFZSAYLVORSSA2LSOVZGKIDEN5WG64RANFXAU4TFOBZGK2DFNZSGK4TJOQQGS3RAOZXWY5LQORQX
IZJAOZSWY2LUEBSXG43FEBRWS3DMOVWSAZDPNRXXEZJAMV2SAZTVM5UWC5BANZ2WY3DBBJYGC4TJMF
2HK4ROEBCXQY3FOB2GK5LSEBZWS3TUEBXWGY3BMVRWC5BAMN2XA2LEMF2GC5BANZXW4IDQOJXWSZDF
NZ2CYIDTOVXHIIDJNYFGG5LMOBQSA4LVNEQG6ZTGNFRWSYJAMRSXGZLSOVXHIIDNN5WGY2LUEBQW42
LNEBUWIIDFON2CA3DBMJXXE5LNFY==
====`
encodedShort := strings.Replace(encoded, "\n", "", -1)
dec := NewDecoder(StdEncoding, bytes.NewBufferString(encoded))
res1, err := ioutil.ReadAll(dec)
if err != nil {
t.Errorf("ReadAll failed: %v", err)
}
dec = NewDecoder(StdEncoding, bytes.NewBufferString(encodedShort))
var res2 []byte
res2, err = ioutil.ReadAll(dec)
if err != nil {
t.Errorf("ReadAll failed: %v", err)
}
if !bytes.Equal(res1, res2) {
t.Error("Decoded results not equal")
}
}