Commit 7b1c3dd9 by Ian Lance Taylor

libgo: Update to weekly.2011-12-02.

From-SVN: r182295
parent 36cfbee1
......@@ -20047,11 +20047,10 @@ var gettysburg = " Four score and seven years ago our fathers brought forth on\
"\n" +
"Abraham Lincoln, November 19, 1863, Gettysburg, Pennsylvania\n"
func main() {
m := md5.New()
io.WriteString(m, data)
hash := fmt.Sprintf("%x", m.Sum())
hash := fmt.Sprintf("%x", m.Sum(nil))
if hash != "525f06bc62a65017cd2217d7584e5920" {
println("BUG a", hash)
return
......@@ -20059,7 +20058,7 @@ func main() {
m = md5.New()
io.WriteString(m, gettysburg)
hash = fmt.Sprintf("%x", m.Sum())
hash = fmt.Sprintf("%x", m.Sum(nil))
if hash != "d7ec5d9d47a4d166091e8d9ebd7ea0aa" {
println("BUG gettysburg", hash)
println(len(gettysburg))
......@@ -19,9 +19,8 @@ func f() {
func init() {
go f()
time.Nanoseconds()
time.Now()
}
func main() {
}
b4a91b693374
0beb796b4ef8
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
......@@ -233,7 +233,6 @@ toolexeclibgoexpdir = $(toolexeclibgodir)/exp
toolexeclibgoexp_DATA = \
exp/ebnf.gox \
exp/gui.gox \
$(exp_inotify_gox) \
exp/norm.gox \
exp/spdy.gox \
......@@ -242,11 +241,6 @@ toolexeclibgoexp_DATA = \
exp/terminal.gox \
exp/types.gox
toolexeclibgoexpguidir = $(toolexeclibgoexpdir)/gui
toolexeclibgoexpgui_DATA = \
exp/gui/x11.gox
toolexeclibgoexpsqldir = $(toolexeclibgoexpdir)/sql
toolexeclibgoexpsql_DATA = \
......@@ -447,6 +441,7 @@ runtime_files = \
runtime/go-map-len.c \
runtime/go-map-range.c \
runtime/go-nanotime.c \
runtime/go-now.c \
runtime/go-new-map.c \
runtime/go-new.c \
runtime/go-panic.c \
......@@ -576,6 +571,7 @@ go_hash_files = \
go_html_files = \
go/html/const.go \
go/html/doc.go \
go/html/doctype.go \
go/html/entity.go \
go/html/escape.go \
go/html/node.go \
......@@ -888,7 +884,7 @@ go_time_files = \
go/time/sys_unix.go \
go/time/tick.go \
go/time/time.go \
go/time/zoneinfo_posix.go \
go/time/zoneinfo.go \
go/time/zoneinfo_unix.go
go_unicode_files = \
......@@ -1038,6 +1034,7 @@ go_crypto_twofish_files = \
go_crypto_x509_files = \
go/crypto/x509/cert_pool.go \
go/crypto/x509/pkcs1.go \
go/crypto/x509/pkcs8.go \
go/crypto/x509/verify.go \
go/crypto/x509/x509.go
go_crypto_xtea_files = \
......@@ -1135,8 +1132,6 @@ go_encoding_xml_files = \
go_exp_ebnf_files = \
go/exp/ebnf/ebnf.go \
go/exp/ebnf/parser.go
go_exp_gui_files = \
go/exp/gui/gui.go
go_exp_inotify_files = \
go/exp/inotify/inotify_linux.go
go_exp_norm_files = \
......@@ -1178,10 +1173,6 @@ go_exp_types_files = \
go/exp/types/types.go \
go/exp/types/universe.go
go_exp_gui_x11_files = \
go/exp/gui/x11/auth.go \
go/exp/gui/x11/conn.go
go_exp_sql_driver_files = \
go/exp/sql/driver/driver.go \
go/exp/sql/driver/types.go
......@@ -1415,13 +1406,11 @@ go_text_template_files = \
go/text/template/exec.go \
go/text/template/funcs.go \
go/text/template/helper.go \
go/text/template/parse.go \
go/text/template/set.go
go/text/template/template.go
go_text_template_parse_files = \
go/text/template/parse/lex.go \
go/text/template/parse/node.go \
go/text/template/parse/parse.go \
go/text/template/parse/set.go
go/text/template/parse/parse.go
go_sync_atomic_files = \
go/sync/atomic/doc.go
......@@ -1725,14 +1714,12 @@ libgo_go_objs = \
encoding/pem.lo \
encoding/xml.lo \
exp/ebnf.lo \
exp/gui.lo \
exp/norm.lo \
exp/spdy.lo \
exp/sql.lo \
exp/ssh.lo \
exp/terminal.lo \
exp/types.lo \
exp/gui/x11.lo \
exp/sql/driver.lo \
html/template.lo \
go/ast.lo \
......@@ -2784,16 +2771,6 @@ exp/ebnf/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: exp/ebnf/check
@go_include@ exp/gui.lo.dep
exp/gui.lo.dep: $(go_exp_gui_files)
$(BUILDDEPS)
exp/gui.lo: $(go_exp_gui_files)
$(BUILDPACKAGE)
exp/gui/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/gui
@$(CHECK)
.PHONY: exp/gui/check
@go_include@ exp/norm.lo.dep
exp/norm.lo.dep: $(go_exp_norm_files)
$(BUILDDEPS)
......@@ -2854,16 +2831,6 @@ exp/types/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: exp/types/check
@go_include@ exp/gui/x11.lo.dep
exp/gui/x11.lo.dep: $(go_exp_gui_x11_files)
$(BUILDDEPS)
exp/gui/x11.lo: $(go_exp_gui_x11_files)
$(BUILDPACKAGE)
exp/gui/x11/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/gui/x11
@$(CHECK)
.PHONY: exp/gui/x11/check
@go_include@ exp/inotify.lo.dep
exp/inotify.lo.dep: $(go_exp_inotify_files)
$(BUILDDEPS)
......@@ -3686,8 +3653,6 @@ encoding/xml.gox: encoding/xml.lo
exp/ebnf.gox: exp/ebnf.lo
$(BUILDGOX)
exp/gui.gox: exp/gui.lo
$(BUILDGOX)
exp/inotify.gox: exp/inotify.lo
$(BUILDGOX)
exp/norm.gox: exp/norm.lo
......@@ -3703,9 +3668,6 @@ exp/terminal.gox: exp/terminal.lo
exp/types.gox: exp/types.lo
$(BUILDGOX)
exp/gui/x11.gox: exp/gui/x11.lo
$(BUILDGOX)
exp/sql/driver.gox: exp/sql/driver.lo
$(BUILDGOX)
......@@ -3950,6 +3912,7 @@ TEST_PACKAGES = \
html/template/check \
go/ast/check \
$(go_build_check_omitted_since_it_calls_6g) \
go/doc/check \
go/parser/check \
go/printer/check \
go/scanner/check \
......
......@@ -11,41 +11,42 @@
// http://www.gnu.org/software/tar/manual/html_node/Standard.html
package tar
import "time"
const (
blockSize = 512
// Types
TypeReg = '0' // regular file.
TypeRegA = '\x00' // regular file.
TypeLink = '1' // hard link.
TypeSymlink = '2' // symbolic link.
TypeChar = '3' // character device node.
TypeBlock = '4' // block device node.
TypeDir = '5' // directory.
TypeFifo = '6' // fifo node.
TypeCont = '7' // reserved.
TypeXHeader = 'x' // extended header.
TypeXGlobalHeader = 'g' // global extended header.
TypeReg = '0' // regular file
TypeRegA = '\x00' // regular file
TypeLink = '1' // hard link
TypeSymlink = '2' // symbolic link
TypeChar = '3' // character device node
TypeBlock = '4' // block device node
TypeDir = '5' // directory
TypeFifo = '6' // fifo node
TypeCont = '7' // reserved
TypeXHeader = 'x' // extended header
TypeXGlobalHeader = 'g' // global extended header
)
// A Header represents a single header in a tar archive.
// Some fields may not be populated.
type Header struct {
Name string // name of header file entry.
Mode int64 // permission and mode bits.
Uid int // user id of owner.
Gid int // group id of owner.
Size int64 // length in bytes.
Mtime int64 // modified time; seconds since epoch.
Typeflag byte // type of header entry.
Linkname string // target name of link.
Uname string // user name of owner.
Gname string // group name of owner.
Devmajor int64 // major number of character or block device.
Devminor int64 // minor number of character or block device.
Atime int64 // access time; seconds since epoch.
Ctime int64 // status change time; seconds since epoch.
Name string // name of header file entry
Mode int64 // permission and mode bits
Uid int // user id of owner
Gid int // group id of owner
Size int64 // length in bytes
ModTime time.Time // modified time
Typeflag byte // type of header entry
Linkname string // target name of link
Uname string // user name of owner
Gname string // group name of owner
Devmajor int64 // major number of character or block device
Devminor int64 // minor number of character or block device
AccessTime time.Time // access time
ChangeTime time.Time // status change time
}
var zeroBlock = make([]byte, blockSize)
......
......@@ -14,6 +14,7 @@ import (
"io/ioutil"
"os"
"strconv"
"time"
)
var (
......@@ -141,7 +142,7 @@ func (tr *Reader) readHeader() *Header {
hdr.Uid = int(tr.octal(s.next(8)))
hdr.Gid = int(tr.octal(s.next(8)))
hdr.Size = tr.octal(s.next(12))
hdr.Mtime = tr.octal(s.next(12))
hdr.ModTime = time.Unix(tr.octal(s.next(12)), 0)
s.next(8) // chksum
hdr.Typeflag = s.next(1)[0]
hdr.Linkname = cString(s.next(100))
......@@ -178,8 +179,8 @@ func (tr *Reader) readHeader() *Header {
prefix = cString(s.next(155))
case "star":
prefix = cString(s.next(131))
hdr.Atime = tr.octal(s.next(12))
hdr.Ctime = tr.octal(s.next(12))
hdr.AccessTime = time.Unix(tr.octal(s.next(12)), 0)
hdr.ChangeTime = time.Unix(tr.octal(s.next(12)), 0)
}
if len(prefix) > 0 {
hdr.Name = prefix + "/" + hdr.Name
......
......@@ -12,6 +12,7 @@ import (
"os"
"reflect"
"testing"
"time"
)
type untarTest struct {
......@@ -29,7 +30,7 @@ var gnuTarTest = &untarTest{
Uid: 73025,
Gid: 5000,
Size: 5,
Mtime: 1244428340,
ModTime: time.Unix(1244428340, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
......@@ -40,7 +41,7 @@ var gnuTarTest = &untarTest{
Uid: 73025,
Gid: 5000,
Size: 11,
Mtime: 1244436044,
ModTime: time.Unix(1244436044, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
......@@ -58,30 +59,30 @@ var untarTests = []*untarTest{
file: "testdata/star.tar",
headers: []*Header{
&Header{
Name: "small.txt",
Mode: 0640,
Uid: 73025,
Gid: 5000,
Size: 5,
Mtime: 1244592783,
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
Atime: 1244592783,
Ctime: 1244592783,
Name: "small.txt",
Mode: 0640,
Uid: 73025,
Gid: 5000,
Size: 5,
ModTime: time.Unix(1244592783, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
AccessTime: time.Unix(1244592783, 0),
ChangeTime: time.Unix(1244592783, 0),
},
&Header{
Name: "small2.txt",
Mode: 0640,
Uid: 73025,
Gid: 5000,
Size: 11,
Mtime: 1244592783,
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
Atime: 1244592783,
Ctime: 1244592783,
Name: "small2.txt",
Mode: 0640,
Uid: 73025,
Gid: 5000,
Size: 11,
ModTime: time.Unix(1244592783, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
AccessTime: time.Unix(1244592783, 0),
ChangeTime: time.Unix(1244592783, 0),
},
},
},
......@@ -94,7 +95,7 @@ var untarTests = []*untarTest{
Uid: 73025,
Gid: 5000,
Size: 5,
Mtime: 1244593104,
ModTime: time.Unix(1244593104, 0),
Typeflag: '\x00',
},
&Header{
......@@ -103,7 +104,7 @@ var untarTests = []*untarTest{
Uid: 73025,
Gid: 5000,
Size: 11,
Mtime: 1244593104,
ModTime: time.Unix(1244593104, 0),
Typeflag: '\x00',
},
},
......@@ -221,7 +222,7 @@ func TestIncrementalRead(t *testing.T) {
h.Write(rdbuf[0:nr])
}
// verify checksum
have := fmt.Sprintf("%x", h.Sum())
have := fmt.Sprintf("%x", h.Sum(nil))
want := cksums[nread]
if want != have {
t.Errorf("Bad checksum on file %s:\nhave %+v\nwant %+v", hdr.Name, have, want)
......
......@@ -127,19 +127,19 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
// TODO(dsymonds): handle names longer than 100 chars
copy(s.next(100), []byte(hdr.Name))
tw.octal(s.next(8), hdr.Mode) // 100:108
tw.numeric(s.next(8), int64(hdr.Uid)) // 108:116
tw.numeric(s.next(8), int64(hdr.Gid)) // 116:124
tw.numeric(s.next(12), hdr.Size) // 124:136
tw.numeric(s.next(12), hdr.Mtime) // 136:148
s.next(8) // chksum (148:156)
s.next(1)[0] = hdr.Typeflag // 156:157
tw.cString(s.next(100), hdr.Linkname) // linkname (157:257)
copy(s.next(8), []byte("ustar\x0000")) // 257:265
tw.cString(s.next(32), hdr.Uname) // 265:297
tw.cString(s.next(32), hdr.Gname) // 297:329
tw.numeric(s.next(8), hdr.Devmajor) // 329:337
tw.numeric(s.next(8), hdr.Devminor) // 337:345
tw.octal(s.next(8), hdr.Mode) // 100:108
tw.numeric(s.next(8), int64(hdr.Uid)) // 108:116
tw.numeric(s.next(8), int64(hdr.Gid)) // 116:124
tw.numeric(s.next(12), hdr.Size) // 124:136
tw.numeric(s.next(12), hdr.ModTime.Unix()) // 136:148
s.next(8) // chksum (148:156)
s.next(1)[0] = hdr.Typeflag // 156:157
tw.cString(s.next(100), hdr.Linkname) // linkname (157:257)
copy(s.next(8), []byte("ustar\x0000")) // 257:265
tw.cString(s.next(32), hdr.Uname) // 265:297
tw.cString(s.next(32), hdr.Gname) // 297:329
tw.numeric(s.next(8), hdr.Devmajor) // 329:337
tw.numeric(s.next(8), hdr.Devminor) // 337:345
// Use the GNU magic instead of POSIX magic if we used any GNU extensions.
if tw.usedBinary {
......
......@@ -11,6 +11,7 @@ import (
"io/ioutil"
"testing"
"testing/iotest"
"time"
)
type writerTestEntry struct {
......@@ -38,7 +39,7 @@ var writerTests = []*writerTest{
Uid: 73025,
Gid: 5000,
Size: 5,
Mtime: 1246508266,
ModTime: time.Unix(1246508266, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
......@@ -52,7 +53,7 @@ var writerTests = []*writerTest{
Uid: 73025,
Gid: 5000,
Size: 11,
Mtime: 1245217492,
ModTime: time.Unix(1245217492, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
......@@ -66,7 +67,7 @@ var writerTests = []*writerTest{
Uid: 1000,
Gid: 1000,
Size: 0,
Mtime: 1314603082,
ModTime: time.Unix(1314603082, 0),
Typeflag: '2',
Linkname: "small.txt",
Uname: "strings",
......@@ -89,7 +90,7 @@ var writerTests = []*writerTest{
Uid: 73025,
Gid: 5000,
Size: 16 << 30,
Mtime: 1254699560,
ModTime: time.Unix(1254699560, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
......
......@@ -56,7 +56,7 @@ func OpenReader(name string) (*ReadCloser, error) {
return nil, err
}
r := new(ReadCloser)
if err := r.init(f, fi.Size); err != nil {
if err := r.init(f, fi.Size()); err != nil {
f.Close()
return nil, err
}
......
......@@ -164,8 +164,8 @@ func readTestFile(t *testing.T, ft ZipTestFile, f *File) {
t.Error(err)
return
}
if got, want := f.Mtime_ns()/1e9, mtime.Seconds(); got != want {
t.Errorf("%s: mtime=%s (%d); want %s (%d)", f.Name, time.SecondsToUTC(got), got, mtime, want)
if ft := f.ModTime(); !ft.Equal(mtime) {
t.Errorf("%s: mtime=%s, want %s", f.Name, ft, mtime)
}
testFileMode(t, f, ft.Mode)
......
......@@ -11,8 +11,10 @@ This package does not support ZIP64 or disk spanning.
*/
package zip
import "errors"
import "time"
import (
"errors"
"time"
)
// Compression methods.
const (
......@@ -74,24 +76,26 @@ func recoverError(errp *error) {
// The resolution is 2s.
// See: http://msdn.microsoft.com/en-us/library/ms724247(v=VS.85).aspx
func msDosTimeToTime(dosDate, dosTime uint16) time.Time {
return time.Time{
return time.Date(
// date bits 0-4: day of month; 5-8: month; 9-15: years since 1980
Year: int64(dosDate>>9 + 1980),
Month: int(dosDate >> 5 & 0xf),
Day: int(dosDate & 0x1f),
int(dosDate>>9+1980),
time.Month(dosDate>>5&0xf),
int(dosDate&0x1f),
// time bits 0-4: second/2; 5-10: minute; 11-15: hour
Hour: int(dosTime >> 11),
Minute: int(dosTime >> 5 & 0x3f),
Second: int(dosTime & 0x1f * 2),
}
int(dosTime>>11),
int(dosTime>>5&0x3f),
int(dosTime&0x1f*2),
0, // nanoseconds
time.UTC,
)
}
// Mtime_ns returns the modified time in ns since epoch.
// ModTime returns the modification time.
// The resolution is 2s.
func (h *FileHeader) Mtime_ns() int64 {
t := msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
return t.Seconds() * 1e9
func (h *FileHeader) ModTime() time.Time {
return msDosTimeToTime(h.ModifiedDate, h.ModifiedTime)
}
// Mode returns the permission and mode bits for the FileHeader.
......
......@@ -702,7 +702,7 @@ func TestTrim(t *testing.T) {
case "TrimRight":
f = TrimRight
default:
t.Error("Undefined trim function %s", name)
t.Errorf("Undefined trim function %s", name)
}
actual := string(f([]byte(tc.in), tc.cutset))
if actual != tc.out {
......
......@@ -13,6 +13,7 @@ import (
"hash"
"hash/crc32"
"io"
"time"
)
// BUG(nigeltao): Comments and Names don't properly map UTF-8 character codes outside of
......@@ -42,11 +43,11 @@ var ChecksumError = errors.New("gzip checksum error")
// The gzip file stores a header giving metadata about the compressed file.
// That header is exposed as the fields of the Compressor and Decompressor structs.
type Header struct {
Comment string // comment
Extra []byte // "extra data"
Mtime uint32 // modification time (seconds since January 1, 1970)
Name string // file name
OS byte // operating system type
Comment string // comment
Extra []byte // "extra data"
ModTime time.Time // modification time
Name string // file name
OS byte // operating system type
}
// An Decompressor is an io.Reader that can be read to retrieve
......@@ -130,7 +131,7 @@ func (z *Decompressor) readHeader(save bool) error {
}
z.flg = z.buf[3]
if save {
z.Mtime = get4(z.buf[4:8])
z.ModTime = time.Unix(int64(get4(z.buf[4:8])), 0)
// z.buf[8] is xfl, ignored
z.OS = z.buf[9]
}
......
......@@ -122,7 +122,7 @@ func (z *Compressor) Write(p []byte) (int, error) {
if z.Comment != "" {
z.buf[3] |= 0x10
}
put4(z.buf[4:8], z.Mtime)
put4(z.buf[4:8], uint32(z.ModTime.Unix()))
if z.level == BestCompression {
z.buf[8] = 2
} else if z.level == BestSpeed {
......
......@@ -8,6 +8,7 @@ import (
"io"
"io/ioutil"
"testing"
"time"
)
// pipe creates two ends of a pipe that gzip and gunzip, and runs dfunc at the
......@@ -53,7 +54,7 @@ func TestWriter(t *testing.T) {
func(compressor *Compressor) {
compressor.Comment = "comment"
compressor.Extra = []byte("extra")
compressor.Mtime = 1e8
compressor.ModTime = time.Unix(1e8, 0)
compressor.Name = "name"
_, err := compressor.Write([]byte("payload"))
if err != nil {
......@@ -74,8 +75,8 @@ func TestWriter(t *testing.T) {
if string(decompressor.Extra) != "extra" {
t.Fatalf("extra is %q, want %q", decompressor.Extra, "extra")
}
if decompressor.Mtime != 1e8 {
t.Fatalf("mtime is %d, want %d", decompressor.Mtime, uint32(1e8))
if decompressor.ModTime.Unix() != 1e8 {
t.Fatalf("mtime is %d, want %d", decompressor.ModTime.Unix(), uint32(1e8))
}
if decompressor.Name != "name" {
t.Fatalf("name is %q, want %q", decompressor.Name, "name")
......
......@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package bcrypt implements Provos and Mazières's bcrypt adapative hashing
// Package bcrypt implements Provos and Mazières's bcrypt adaptive hashing
// algorithm. See http://www.usenix.org/event/usenix99/provos/provos.pdf
package bcrypt
......
......@@ -214,7 +214,7 @@ func TestVectors(t *testing.T) {
msg, _ := hex.DecodeString(test.msg)
sha.Reset()
sha.Write(msg)
hashed := sha.Sum()
hashed := sha.Sum(nil)
r := fromHex(test.r)
s := fromHex(test.s)
if Verify(&pub, hashed, r, s) != test.ok {
......
......@@ -48,15 +48,15 @@ func (h *hmac) tmpPad(xor byte) {
}
}
func (h *hmac) Sum() []byte {
sum := h.inner.Sum()
func (h *hmac) Sum(in []byte) []byte {
sum := h.inner.Sum(nil)
h.tmpPad(0x5c)
for i, b := range sum {
h.tmp[padSize+i] = b
}
h.outer.Reset()
h.outer.Write(h.tmp)
return h.outer.Sum()
return h.outer.Sum(in)
}
func (h *hmac) Write(p []byte) (n int, err error) {
......@@ -81,7 +81,7 @@ func New(h func() hash.Hash, key []byte) hash.Hash {
if len(key) > padSize {
// If key is too big, hash it.
hm.outer.Write(key)
key = hm.outer.Sum()
key = hm.outer.Sum(nil)
}
hm.key = make([]byte, len(key))
copy(hm.key, key)
......
......@@ -192,7 +192,7 @@ func TestHMAC(t *testing.T) {
// Repetitive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum())
sum := fmt.Sprintf("%x", h.Sum(nil))
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s\n", i, j, k, sum, tt.out)
}
......
......@@ -77,7 +77,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
return
}
func (d0 *digest) Sum() []byte {
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0, so that caller can keep writing and summing.
d := new(digest)
*d = *d0
......@@ -103,14 +103,11 @@ func (d0 *digest) Sum() []byte {
panic("d.nx != 0")
}
p := make([]byte, 16)
j := 0
for _, s := range d.s {
p[j+0] = byte(s >> 0)
p[j+1] = byte(s >> 8)
p[j+2] = byte(s >> 16)
p[j+3] = byte(s >> 24)
j += 4
in = append(in, byte(s>>0))
in = append(in, byte(s>>8))
in = append(in, byte(s>>16))
in = append(in, byte(s>>24))
}
return p
return in
}
......@@ -58,10 +58,10 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in)
} else {
io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum()
c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:])
}
s := fmt.Sprintf("%x", c.Sum())
s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("md4[%d](%s) = %s want %s", j, g.in, s, g.out)
}
......
......@@ -77,7 +77,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
return
}
func (d0 *digest) Sum() []byte {
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing.
d := new(digest)
*d = *d0
......@@ -103,14 +103,11 @@ func (d0 *digest) Sum() []byte {
panic("d.nx != 0")
}
p := make([]byte, 16)
j := 0
for _, s := range d.s {
p[j+0] = byte(s >> 0)
p[j+1] = byte(s >> 8)
p[j+2] = byte(s >> 16)
p[j+3] = byte(s >> 24)
j += 4
in = append(in, byte(s>>0))
in = append(in, byte(s>>8))
in = append(in, byte(s>>16))
in = append(in, byte(s>>24))
}
return p
return in
}
......@@ -58,10 +58,10 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in)
} else {
io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum()
c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:])
}
s := fmt.Sprintf("%x", c.Sum())
s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("md5[%d](%s) = %s want %s", j, g.in, s, g.out)
}
......
......@@ -61,7 +61,7 @@ type responseData struct {
Version int `asn1:"optional,default:1,explicit,tag:0"`
RequestorName pkix.RDNSequence `asn1:"optional,explicit,tag:1"`
KeyHash []byte `asn1:"optional,explicit,tag:2"`
ProducedAt *time.Time
ProducedAt time.Time
Responses []singleResponse
}
......@@ -70,12 +70,12 @@ type singleResponse struct {
Good asn1.Flag `asn1:"explicit,tag:0,optional"`
Revoked revokedInfo `asn1:"explicit,tag:1,optional"`
Unknown asn1.Flag `asn1:"explicit,tag:2,optional"`
ThisUpdate *time.Time
NextUpdate *time.Time `asn1:"explicit,tag:0,optional"`
ThisUpdate time.Time
NextUpdate time.Time `asn1:"explicit,tag:0,optional"`
}
type revokedInfo struct {
RevocationTime *time.Time
RevocationTime time.Time
Reason int `asn1:"explicit,tag:0,optional"`
}
......@@ -97,7 +97,7 @@ type Response struct {
// Status is one of {Good, Revoked, Unknown, ServerFailed}
Status int
SerialNumber []byte
ProducedAt, ThisUpdate, NextUpdate, RevokedAt *time.Time
ProducedAt, ThisUpdate, NextUpdate, RevokedAt time.Time
RevocationReason int
Certificate *x509.Certificate
}
......@@ -161,7 +161,7 @@ func ParseResponse(bytes []byte) (*Response, error) {
pub := ret.Certificate.PublicKey.(*rsa.PublicKey)
h.Write(basicResp.TBSResponseData.Raw)
digest := h.Sum()
digest := h.Sum(nil)
signature := basicResp.Signature.RightAlign()
if rsa.VerifyPKCS1v15(pub, hashType, digest, signature) != nil {
......
......@@ -15,7 +15,13 @@ func TestOCSPDecode(t *testing.T) {
t.Error(err)
}
expected := Response{Status: 0, SerialNumber: []byte{0x1, 0xd0, 0xfa}, RevocationReason: 0, ThisUpdate: &time.Time{Year: 2010, Month: 7, Day: 7, Hour: 15, Minute: 1, Second: 5, ZoneOffset: 0, Zone: "UTC"}, NextUpdate: &time.Time{Year: 2010, Month: 7, Day: 7, Hour: 18, Minute: 35, Second: 17, ZoneOffset: 0, Zone: "UTC"}}
expected := Response{
Status: 0,
SerialNumber: []byte{0x1, 0xd0, 0xfa},
RevocationReason: 0,
ThisUpdate: time.Date(2010, 7, 7, 15, 1, 5, 0, time.UTC),
NextUpdate: time.Date(2010, 7, 7, 18, 35, 17, 0, time.UTC),
}
if !reflect.DeepEqual(resp.ThisUpdate, resp.ThisUpdate) {
t.Errorf("resp.ThisUpdate: got %d, want %d", resp.ThisUpdate, expected.ThisUpdate)
......
......@@ -41,8 +41,8 @@ func (cth *canonicalTextHash) Write(buf []byte) (int, error) {
return len(buf), nil
}
func (cth *canonicalTextHash) Sum() []byte {
return cth.h.Sum()
func (cth *canonicalTextHash) Sum(in []byte) []byte {
return cth.h.Sum(in)
}
func (cth *canonicalTextHash) Reset() {
......
......@@ -17,8 +17,8 @@ func (r recordingHash) Write(b []byte) (n int, err error) {
return r.buf.Write(b)
}
func (r recordingHash) Sum() []byte {
return r.buf.Bytes()
func (r recordingHash) Sum(in []byte) []byte {
return append(in, r.buf.Bytes()...)
}
func (r recordingHash) Reset() {
......@@ -33,7 +33,7 @@ func testCanonicalText(t *testing.T, input, expected string) {
r := recordingHash{bytes.NewBuffer(nil)}
c := NewCanonicalTextHash(r)
c.Write([]byte(input))
result := c.Sum()
result := c.Sum(nil)
if expected != string(result) {
t.Errorf("input: %x got: %x want: %x", input, result, expected)
}
......
......@@ -381,7 +381,7 @@ const defaultRSAKeyBits = 2048
// NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a
// single identity composed of the given full name, comment and email, any of
// which may be empty but must not contain any of "()<>\x00".
func NewEntity(rand io.Reader, currentTimeSecs int64, name, comment, email string) (*Entity, error) {
func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email string) (*Entity, error) {
uid := packet.NewUserId(name, comment, email)
if uid == nil {
return nil, error_.InvalidArgumentError("user id field contained invalid characters")
......@@ -395,11 +395,9 @@ func NewEntity(rand io.Reader, currentTimeSecs int64, name, comment, email strin
return nil, err
}
t := uint32(currentTimeSecs)
e := &Entity{
PrimaryKey: packet.NewRSAPublicKey(t, &signingPriv.PublicKey, false /* not a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(t, signingPriv, false /* not a subkey */ ),
PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey, false /* not a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv, false /* not a subkey */ ),
Identities: make(map[string]*Identity),
}
isPrimaryId := true
......@@ -407,7 +405,7 @@ func NewEntity(rand io.Reader, currentTimeSecs int64, name, comment, email strin
Name: uid.Name,
UserId: uid,
SelfSignature: &packet.Signature{
CreationTime: t,
CreationTime: currentTime,
SigType: packet.SigTypePositiveCert,
PubKeyAlgo: packet.PubKeyAlgoRSA,
Hash: crypto.SHA256,
......@@ -421,10 +419,10 @@ func NewEntity(rand io.Reader, currentTimeSecs int64, name, comment, email strin
e.Subkeys = make([]Subkey, 1)
e.Subkeys[0] = Subkey{
PublicKey: packet.NewRSAPublicKey(t, &encryptingPriv.PublicKey, true /* is a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(t, encryptingPriv, true /* is a subkey */ ),
PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey, true /* is a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv, true /* is a subkey */ ),
Sig: &packet.Signature{
CreationTime: t,
CreationTime: currentTime,
SigType: packet.SigTypeSubkeyBinding,
PubKeyAlgo: packet.PubKeyAlgoRSA,
Hash: crypto.SHA256,
......@@ -533,7 +531,7 @@ func (e *Entity) SignIdentity(identity string, signer *Entity) error {
SigType: packet.SigTypeGenericCert,
PubKeyAlgo: signer.PrivateKey.PubKeyAlgo,
Hash: crypto.SHA256,
CreationTime: uint32(time.Seconds()),
CreationTime: time.Now(),
IssuerKeyId: &signer.PrivateKey.KeyId,
}
if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil {
......
......@@ -17,6 +17,7 @@ import (
"io/ioutil"
"math/big"
"strconv"
"time"
)
// PrivateKey represents a possibly encrypted private key. See RFC 4880,
......@@ -32,9 +33,9 @@ type PrivateKey struct {
iv []byte
}
func NewRSAPrivateKey(currentTimeSecs uint32, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewRSAPublicKey(currentTimeSecs, &priv.PublicKey, isSubkey)
pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey, isSubkey)
pk.PrivateKey = priv
return pk
}
......@@ -99,13 +100,9 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
}
func mod64kHash(d []byte) uint16 {
h := uint16(0)
for i := 0; i < len(d); i += 2 {
v := uint16(d[i]) << 8
if i+1 < len(d) {
v += uint16(d[i+1])
}
h += v
var h uint16
for _, b := range d {
h += uint16(b)
}
return h
}
......@@ -195,7 +192,7 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
}
h := sha1.New()
h.Write(data[:len(data)-sha1.Size])
sum := h.Sum()
sum := h.Sum(nil)
if !bytes.Equal(sum, data[len(data)-sha1.Size:]) {
return error_.StructuralError("private key checksum failure")
}
......
......@@ -6,19 +6,20 @@ package packet
import (
"testing"
"time"
)
var privateKeyTests = []struct {
privateKeyHex string
creationTime uint32
creationTime time.Time
}{
{
privKeyRSAHex,
0x4cc349a8,
time.Unix(0x4cc349a8, 0),
},
{
privKeyElGamalHex,
0x4df9ee1a,
time.Unix(0x4df9ee1a, 0),
},
}
......@@ -43,7 +44,7 @@ func TestPrivateKeyRead(t *testing.T) {
continue
}
if privKey.CreationTime != test.creationTime || privKey.Encrypted {
if !privKey.CreationTime.Equal(test.creationTime) || privKey.Encrypted {
t.Errorf("#%d: bad result, got: %#v", i, privKey)
}
}
......
......@@ -16,11 +16,12 @@ import (
"io"
"math/big"
"strconv"
"time"
)
// PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2.
type PublicKey struct {
CreationTime uint32 // seconds since the epoch
CreationTime time.Time
PubKeyAlgo PublicKeyAlgorithm
PublicKey interface{} // Either a *rsa.PublicKey or *dsa.PublicKey
Fingerprint [20]byte
......@@ -38,9 +39,9 @@ func fromBig(n *big.Int) parsedMPI {
}
// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
func NewRSAPublicKey(creationTimeSecs uint32, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
pk := &PublicKey{
CreationTime: creationTimeSecs,
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoRSA,
PublicKey: pub,
IsSubkey: isSubkey,
......@@ -62,7 +63,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
if buf[0] != 4 {
return error_.UnsupportedError("public key version")
}
pk.CreationTime = uint32(buf[1])<<24 | uint32(buf[2])<<16 | uint32(buf[3])<<8 | uint32(buf[4])
pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5])
switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
......@@ -87,7 +88,7 @@ func (pk *PublicKey) setFingerPrintAndKeyId() {
fingerPrint := sha1.New()
pk.SerializeSignaturePrefix(fingerPrint)
pk.serializeWithoutHeaders(fingerPrint)
copy(pk.Fingerprint[:], fingerPrint.Sum())
copy(pk.Fingerprint[:], fingerPrint.Sum(nil))
pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20])
}
......@@ -234,10 +235,11 @@ func (pk *PublicKey) Serialize(w io.Writer) (err error) {
func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
var buf [6]byte
buf[0] = 4
buf[1] = byte(pk.CreationTime >> 24)
buf[2] = byte(pk.CreationTime >> 16)
buf[3] = byte(pk.CreationTime >> 8)
buf[4] = byte(pk.CreationTime)
t := uint32(pk.CreationTime.Unix())
buf[1] = byte(t >> 24)
buf[2] = byte(t >> 16)
buf[3] = byte(t >> 8)
buf[4] = byte(t)
buf[5] = byte(pk.PubKeyAlgo)
_, err = w.Write(buf[:])
......@@ -269,7 +271,7 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
}
signed.Write(sig.HashSuffix)
hashBytes := signed.Sum()
hashBytes := signed.Sum(nil)
if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
return error_.SignatureError("hash tag doesn't match")
......
......@@ -8,19 +8,20 @@ import (
"bytes"
"encoding/hex"
"testing"
"time"
)
var pubKeyTests = []struct {
hexData string
hexFingerprint string
creationTime uint32
creationTime time.Time
pubKeyAlgo PublicKeyAlgorithm
keyId uint64
keyIdString string
keyIdShort string
}{
{rsaPkDataHex, rsaFingerprintHex, 0x4d3c5c10, PubKeyAlgoRSA, 0xa34d7e18c20c31bb, "A34D7E18C20C31BB", "C20C31BB"},
{dsaPkDataHex, dsaFingerprintHex, 0x4d432f89, PubKeyAlgoDSA, 0x8e8fbe54062f19ed, "8E8FBE54062F19ED", "062F19ED"},
{rsaPkDataHex, rsaFingerprintHex, time.Unix(0x4d3c5c10, 0), PubKeyAlgoRSA, 0xa34d7e18c20c31bb, "A34D7E18C20C31BB", "C20C31BB"},
{dsaPkDataHex, dsaFingerprintHex, time.Unix(0x4d432f89, 0), PubKeyAlgoDSA, 0x8e8fbe54062f19ed, "8E8FBE54062F19ED", "062F19ED"},
}
func TestPublicKeyRead(t *testing.T) {
......@@ -38,8 +39,8 @@ func TestPublicKeyRead(t *testing.T) {
if pk.PubKeyAlgo != test.pubKeyAlgo {
t.Errorf("#%d: bad public key algorithm got:%x want:%x", i, pk.PubKeyAlgo, test.pubKeyAlgo)
}
if pk.CreationTime != test.creationTime {
t.Errorf("#%d: bad creation time got:%x want:%x", i, pk.CreationTime, test.creationTime)
if !pk.CreationTime.Equal(test.creationTime) {
t.Errorf("#%d: bad creation time got:%v want:%v", i, pk.CreationTime, test.creationTime)
}
expectedFingerprint, _ := hex.DecodeString(test.hexFingerprint)
if !bytes.Equal(expectedFingerprint, pk.Fingerprint[:]) {
......
......@@ -15,6 +15,7 @@ import (
"hash"
"io"
"strconv"
"time"
)
// Signature represents a signature. See RFC 4880, section 5.2.
......@@ -28,7 +29,7 @@ type Signature struct {
// HashTag contains the first two bytes of the hash for fast rejection
// of bad signed data.
HashTag [2]byte
CreationTime uint32 // Unix epoch time
CreationTime time.Time
RSASignature parsedMPI
DSASigR, DSASigS parsedMPI
......@@ -151,7 +152,7 @@ func parseSignatureSubpackets(sig *Signature, subpackets []byte, isHashed bool)
}
}
if sig.CreationTime == 0 {
if sig.CreationTime.IsZero() {
err = error_.StructuralError("no creation time in signature")
}
......@@ -223,7 +224,12 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r
err = error_.StructuralError("signature creation time not four bytes")
return
}
sig.CreationTime = binary.BigEndian.Uint32(subpacket)
t := binary.BigEndian.Uint32(subpacket)
if t == 0 {
sig.CreationTime = time.Time{}
} else {
sig.CreationTime = time.Unix(int64(t), 0)
}
case signatureExpirationSubpacket:
// Signature expiration time, section 5.2.3.10
if !isHashed {
......@@ -417,7 +423,7 @@ func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err error) {
}
h.Write(sig.HashSuffix)
digest = h.Sum()
digest = h.Sum(nil)
copy(sig.HashTag[:], digest)
return
}
......@@ -541,10 +547,7 @@ type outputSubpacket struct {
func (sig *Signature) buildSubpackets() (subpackets []outputSubpacket) {
creationTime := make([]byte, 4)
creationTime[0] = byte(sig.CreationTime >> 24)
creationTime[1] = byte(sig.CreationTime >> 16)
creationTime[2] = byte(sig.CreationTime >> 8)
creationTime[3] = byte(sig.CreationTime)
binary.BigEndian.PutUint32(creationTime, uint32(sig.CreationTime.Unix()))
subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, false, creationTime})
if sig.IssuerKeyId != nil {
......
......@@ -201,7 +201,7 @@ func (ser *seMDCReader) Close() error {
}
ser.h.Write(ser.trailer[:2])
final := ser.h.Sum()
final := ser.h.Sum(nil)
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
return error_.SignatureError("hash mismatch")
}
......@@ -227,7 +227,7 @@ func (w *seMDCWriter) Close() (err error) {
buf[0] = mdcPacketTagByte
buf[1] = sha1.Size
w.h.Write(buf[:2])
digest := w.h.Sum()
digest := w.h.Sum(nil)
copy(buf[2:], digest)
_, err = w.w.Write(buf[:])
......
......@@ -34,7 +34,7 @@ func Salted(out []byte, h hash.Hash, in []byte, salt []byte) {
}
h.Write(salt)
h.Write(in)
n := copy(out[done:], h.Sum())
n := copy(out[done:], h.Sum(nil))
done += n
}
}
......@@ -68,7 +68,7 @@ func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) {
written += len(combined)
}
}
n := copy(out[done:], h.Sum())
n := copy(out[done:], h.Sum(nil))
done += n
}
}
......
......@@ -68,7 +68,7 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
sig.SigType = sigType
sig.PubKeyAlgo = signer.PrivateKey.PubKeyAlgo
sig.Hash = crypto.SHA256
sig.CreationTime = uint32(time.Seconds())
sig.CreationTime = time.Now()
sig.IssuerKeyId = &signer.PrivateKey.KeyId
h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType)
......@@ -95,8 +95,8 @@ type FileHints struct {
// file should not be written to disk. It may be equal to "_CONSOLE" to
// suggest the data should not be written to disk.
FileName string
// EpochSeconds contains the modification time of the file, or 0 if not applicable.
EpochSeconds uint32
// ModTime contains the modification time of the file, or the zero time if not applicable.
ModTime time.Time
}
// SymmetricallyEncrypt acts like gpg -c: it encrypts a file with a passphrase.
......@@ -115,7 +115,11 @@ func SymmetricallyEncrypt(ciphertext io.Writer, passphrase []byte, hints *FileHi
if err != nil {
return
}
return packet.SerializeLiteral(w, hints.IsBinary, hints.FileName, hints.EpochSeconds)
var epochSeconds uint32
if !hints.ModTime.IsZero() {
epochSeconds = uint32(hints.ModTime.Unix())
}
return packet.SerializeLiteral(w, hints.IsBinary, hints.FileName, epochSeconds)
}
// intersectPreferences mutates and returns a prefix of a that contains only
......@@ -243,7 +247,11 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
w = noOpCloser{encryptedData}
}
literalData, err := packet.SerializeLiteral(w, hints.IsBinary, hints.FileName, hints.EpochSeconds)
var epochSeconds uint32
if !hints.ModTime.IsZero() {
epochSeconds = uint32(hints.ModTime.Unix())
}
literalData, err := packet.SerializeLiteral(w, hints.IsBinary, hints.FileName, epochSeconds)
if err != nil {
return nil, err
}
......@@ -275,7 +283,7 @@ func (s signatureWriter) Close() error {
SigType: packet.SigTypeBinary,
PubKeyAlgo: s.signer.PubKeyAlgo,
Hash: s.hashType,
CreationTime: uint32(time.Seconds()),
CreationTime: time.Now(),
IssuerKeyId: &s.signer.KeyId,
}
......
......@@ -54,7 +54,7 @@ func TestNewEntity(t *testing.T) {
return
}
e, err := NewEntity(rand.Reader, time.Seconds(), "Test User", "test", "test@example.com")
e, err := NewEntity(rand.Reader, time.Now(), "Test User", "test", "test@example.com")
if err != nil {
t.Errorf("failed to create entity: %s", err)
return
......
......@@ -100,7 +100,7 @@ func (r *reader) Read(b []byte) (n int, err error) {
// t = encrypt(time)
// dst = encrypt(t^seed)
// seed = encrypt(t^dst)
ns := time.Nanoseconds()
ns := time.Now().UnixNano()
r.time[0] = byte(ns >> 56)
r.time[1] = byte(ns >> 48)
r.time[2] = byte(ns >> 40)
......
......@@ -81,7 +81,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
return
}
func (d0 *digest) Sum() []byte {
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing.
d := new(digest)
*d = *d0
......@@ -107,11 +107,11 @@ func (d0 *digest) Sum() []byte {
panic("d.nx != 0")
}
p := make([]byte, 20)
j := 0
for _, s := range d.s {
p[j], p[j+1], p[j+2], p[j+3] = byte(s), byte(s>>8), byte(s>>16), byte(s>>24)
j += 4
in = append(in, byte(s))
in = append(in, byte(s>>8))
in = append(in, byte(s>>16))
in = append(in, byte(s>>24))
}
return p
return in
}
......@@ -38,10 +38,10 @@ func TestVectors(t *testing.T) {
io.WriteString(md, tv.in)
} else {
io.WriteString(md, tv.in[0:len(tv.in)/2])
md.Sum()
md.Sum(nil)
io.WriteString(md, tv.in[len(tv.in)/2:])
}
s := fmt.Sprintf("%x", md.Sum())
s := fmt.Sprintf("%x", md.Sum(nil))
if s != tv.out {
t.Fatalf("RIPEMD-160[%d](%s) = %s, expected %s", j, tv.in, s, tv.out)
}
......@@ -56,7 +56,7 @@ func TestMillionA(t *testing.T) {
io.WriteString(md, "aaaaaaaaaa")
}
out := "52783243c1697bdbe16d37f97f68f08325dc1528"
s := fmt.Sprintf("%x", md.Sum())
s := fmt.Sprintf("%x", md.Sum(nil))
if s != out {
t.Fatalf("RIPEMD-160 (1 million 'a') = %s, expected %s", s, out)
}
......
......@@ -168,7 +168,7 @@ func TestSignPKCS1v15(t *testing.T) {
for i, test := range signPKCS1v15Tests {
h := sha1.New()
h.Write([]byte(test.in))
digest := h.Sum()
digest := h.Sum(nil)
s, err := SignPKCS1v15(nil, rsaPrivateKey, crypto.SHA1, digest)
if err != nil {
......@@ -186,7 +186,7 @@ func TestVerifyPKCS1v15(t *testing.T) {
for i, test := range signPKCS1v15Tests {
h := sha1.New()
h.Write([]byte(test.in))
digest := h.Sum()
digest := h.Sum(nil)
sig, _ := hex.DecodeString(test.out)
......
......@@ -194,7 +194,7 @@ func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
for done < len(out) {
hash.Write(seed)
hash.Write(counter[0:4])
digest := hash.Sum()
digest := hash.Sum(nil)
hash.Reset()
for i := 0; i < len(digest) && done < len(out); i++ {
......@@ -231,7 +231,7 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
}
hash.Write(label)
lHash := hash.Sum()
lHash := hash.Sum(nil)
hash.Reset()
em := make([]byte, k)
......@@ -428,7 +428,7 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
}
hash.Write(label)
lHash := hash.Sum()
lHash := hash.Sum(nil)
hash.Reset()
// Converting the plaintext number to bytes will strip any
......
......@@ -79,7 +79,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
return
}
func (d0 *digest) Sum() []byte {
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing.
d := new(digest)
*d = *d0
......@@ -105,14 +105,11 @@ func (d0 *digest) Sum() []byte {
panic("d.nx != 0")
}
p := make([]byte, 20)
j := 0
for _, s := range d.h {
p[j+0] = byte(s >> 24)
p[j+1] = byte(s >> 16)
p[j+2] = byte(s >> 8)
p[j+3] = byte(s >> 0)
j += 4
in = append(in, byte(s>>24))
in = append(in, byte(s>>16))
in = append(in, byte(s>>8))
in = append(in, byte(s))
}
return p
return in
}
......@@ -60,10 +60,10 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in)
} else {
io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum()
c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:])
}
s := fmt.Sprintf("%x", c.Sum())
s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("sha1[%d](%s) = %s want %s", j, g.in, s, g.out)
}
......
......@@ -123,7 +123,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
return
}
func (d0 *digest) Sum() []byte {
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing.
d := new(digest)
*d = *d0
......@@ -149,17 +149,15 @@ func (d0 *digest) Sum() []byte {
panic("d.nx != 0")
}
p := make([]byte, 32)
j := 0
for _, s := range d.h {
p[j+0] = byte(s >> 24)
p[j+1] = byte(s >> 16)
p[j+2] = byte(s >> 8)
p[j+3] = byte(s >> 0)
j += 4
}
h := d.h[:]
if d.is224 {
return p[0:28]
h = d.h[:7]
}
for _, s := range h {
in = append(in, byte(s>>24))
in = append(in, byte(s>>16))
in = append(in, byte(s>>8))
in = append(in, byte(s))
}
return p
return in
}
......@@ -94,10 +94,10 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in)
} else {
io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum()
c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:])
}
s := fmt.Sprintf("%x", c.Sum())
s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("sha256[%d](%s) = %s want %s", j, g.in, s, g.out)
}
......@@ -112,10 +112,10 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in)
} else {
io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum()
c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:])
}
s := fmt.Sprintf("%x", c.Sum())
s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("sha224[%d](%s) = %s want %s", j, g.in, s, g.out)
}
......
......@@ -123,7 +123,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
return
}
func (d0 *digest) Sum() []byte {
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing.
d := new(digest)
*d = *d0
......@@ -149,21 +149,19 @@ func (d0 *digest) Sum() []byte {
panic("d.nx != 0")
}
p := make([]byte, 64)
j := 0
for _, s := range d.h {
p[j+0] = byte(s >> 56)
p[j+1] = byte(s >> 48)
p[j+2] = byte(s >> 40)
p[j+3] = byte(s >> 32)
p[j+4] = byte(s >> 24)
p[j+5] = byte(s >> 16)
p[j+6] = byte(s >> 8)
p[j+7] = byte(s >> 0)
j += 8
}
h := d.h[:]
if d.is384 {
return p[0:48]
h = d.h[:6]
}
for _, s := range h {
in = append(in, byte(s>>56))
in = append(in, byte(s>>48))
in = append(in, byte(s>>40))
in = append(in, byte(s>>32))
in = append(in, byte(s>>24))
in = append(in, byte(s>>16))
in = append(in, byte(s>>8))
in = append(in, byte(s))
}
return p
return in
}
......@@ -94,10 +94,10 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in)
} else {
io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum()
c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:])
}
s := fmt.Sprintf("%x", c.Sum())
s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("sha512[%d](%s) = %s want %s", j, g.in, s, g.out)
}
......@@ -112,10 +112,10 @@ func TestGolden(t *testing.T) {
io.WriteString(c, g.in)
} else {
io.WriteString(c, g.in[0:len(g.in)/2])
c.Sum()
c.Sum(nil)
io.WriteString(c, g.in[len(g.in)/2:])
}
s := fmt.Sprintf("%x", c.Sum())
s := fmt.Sprintf("%x", c.Sum(nil))
if s != g.out {
t.Fatalf("sha384[%d](%s) = %s want %s", j, g.in, s, g.out)
}
......
......@@ -37,6 +37,7 @@ type keyAgreement interface {
// A cipherSuite is a specific combination of key agreement, cipher and MAC
// function. All cipher suites currently assume RSA key agreement.
type cipherSuite struct {
id uint16
// the lengths, in bytes, of the key material needed for each component.
keyLen int
macLen int
......@@ -50,13 +51,13 @@ type cipherSuite struct {
mac func(version uint16, macKey []byte) macFunction
}
var cipherSuites = map[uint16]*cipherSuite{
TLS_RSA_WITH_RC4_128_SHA: &cipherSuite{16, 20, 0, rsaKA, false, cipherRC4, macSHA1},
TLS_RSA_WITH_3DES_EDE_CBC_SHA: &cipherSuite{24, 20, 8, rsaKA, false, cipher3DES, macSHA1},
TLS_RSA_WITH_AES_128_CBC_SHA: &cipherSuite{16, 20, 16, rsaKA, false, cipherAES, macSHA1},
TLS_ECDHE_RSA_WITH_RC4_128_SHA: &cipherSuite{16, 20, 0, ecdheRSAKA, true, cipherRC4, macSHA1},
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: &cipherSuite{24, 20, 8, ecdheRSAKA, true, cipher3DES, macSHA1},
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: &cipherSuite{16, 20, 16, ecdheRSAKA, true, cipherAES, macSHA1},
var cipherSuites = []*cipherSuite{
&cipherSuite{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, false, cipherRC4, macSHA1},
&cipherSuite{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, false, cipher3DES, macSHA1},
&cipherSuite{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, false, cipherAES, macSHA1},
&cipherSuite{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, true, cipherRC4, macSHA1},
&cipherSuite{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, true, cipher3DES, macSHA1},
&cipherSuite{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, true, cipherAES, macSHA1},
}
func cipherRC4(key, iv []byte, isRead bool) interface{} {
......@@ -126,13 +127,13 @@ func (s ssl30MAC) MAC(seq, record []byte) []byte {
s.h.Write(record[:1])
s.h.Write(record[3:5])
s.h.Write(record[recordHeaderLen:])
digest := s.h.Sum()
digest := s.h.Sum(nil)
s.h.Reset()
s.h.Write(s.key)
s.h.Write(ssl30Pad2[:padLength])
s.h.Write(digest)
return s.h.Sum()
return s.h.Sum(nil)
}
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3.
......@@ -148,7 +149,7 @@ func (s tls10MAC) MAC(seq, record []byte) []byte {
s.h.Reset()
s.h.Write(seq)
s.h.Write(record)
return s.h.Sum()
return s.h.Sum(nil)
}
func rsaKA() keyAgreement {
......@@ -159,15 +160,20 @@ func ecdheRSAKA() keyAgreement {
return new(ecdheRSAKeyAgreement)
}
// mutualCipherSuite returns a cipherSuite and its id given a list of supported
// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
func mutualCipherSuite(have []uint16, want uint16) (suite *cipherSuite, id uint16) {
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
for _, id := range have {
if id == want {
return cipherSuites[id], id
for _, suite := range cipherSuites {
if suite.id == want {
return suite
}
}
return nil
}
}
return
return nil
}
// A list of the possible cipher suite ids. Taken from
......
......@@ -121,7 +121,7 @@ type Config struct {
// Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses the system time.Seconds.
Time func() int64
Time func() time.Time
// Certificates contains one or more certificate chains
// to present to the other side of the connection.
......@@ -175,10 +175,10 @@ func (c *Config) rand() io.Reader {
return r
}
func (c *Config) time() int64 {
func (c *Config) time() time.Time {
t := c.Time
if t == nil {
t = time.Seconds
t = time.Now
}
return t()
}
......@@ -315,9 +315,7 @@ var (
func initDefaultCipherSuites() {
varDefaultCipherSuites = make([]uint16, len(cipherSuites))
i := 0
for id := range cipherSuites {
varDefaultCipherSuites[i] = id
i++
for i, suite := range cipherSuites {
varDefaultCipherSuites[i] = suite.id
}
}
......@@ -32,7 +32,7 @@ func (c *Conn) clientHandshake() error {
nextProtoNeg: len(c.config.NextProtos) > 0,
}
t := uint32(c.config.time())
t := uint32(c.config.time().Unix())
hello.random[0] = byte(t >> 24)
hello.random[1] = byte(t >> 16)
hello.random[2] = byte(t >> 8)
......@@ -72,7 +72,7 @@ func (c *Conn) clientHandshake() error {
return errors.New("server advertised unrequested NPN")
}
suite, suiteId := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
if suite == nil {
return c.sendAlert(alertHandshakeFailure)
}
......@@ -232,8 +232,8 @@ func (c *Conn) clientHandshake() error {
if cert != nil {
certVerify := new(certificateVerifyMsg)
var digest [36]byte
copy(digest[0:16], finishedHash.serverMD5.Sum())
copy(digest[16:36], finishedHash.serverSHA1.Sum())
copy(digest[0:16], finishedHash.serverMD5.Sum(nil))
copy(digest[16:36], finishedHash.serverSHA1.Sum(nil))
signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, crypto.MD5SHA1, digest[0:])
if err != nil {
return c.sendAlert(alertInternalError)
......@@ -292,7 +292,7 @@ func (c *Conn) clientHandshake() error {
}
c.handshakeComplete = true
c.cipherSuite = suiteId
c.cipherSuite = suite.id
return nil
}
......
......@@ -56,18 +56,25 @@ Curves:
ellipticOk := supportedCurve && supportedPointFormat
var suite *cipherSuite
var suiteId uint16
FindCipherSuite:
for _, id := range clientHello.cipherSuites {
for _, supported := range config.cipherSuites() {
if id == supported {
suite = cipherSuites[id]
suite = nil
for _, s := range cipherSuites {
if s.id == id {
suite = s
break
}
}
if suite == nil {
continue
}
// Don't select a ciphersuite which we can't
// support for this client.
if suite.elliptic && !ellipticOk {
continue
}
suiteId = id
break FindCipherSuite
}
}
......@@ -87,8 +94,8 @@ FindCipherSuite:
}
hello.vers = vers
hello.cipherSuite = suiteId
t := uint32(config.time())
hello.cipherSuite = suite.id
t := uint32(config.time().Unix())
hello.random = make([]byte, 32)
hello.random[0] = byte(t >> 24)
hello.random[1] = byte(t >> 16)
......@@ -228,8 +235,8 @@ FindCipherSuite:
}
digest := make([]byte, 36)
copy(digest[0:16], finishedHash.serverMD5.Sum())
copy(digest[16:36], finishedHash.serverSHA1.Sum())
copy(digest[0:16], finishedHash.serverMD5.Sum(nil))
copy(digest[16:36], finishedHash.serverSHA1.Sum(nil))
err = rsa.VerifyPKCS1v15(pub, crypto.MD5SHA1, digest, certVerify.signature)
if err != nil {
c.sendAlert(alertBadCertificate)
......@@ -296,7 +303,7 @@ FindCipherSuite:
c.writeRecord(recordTypeHandshake, finished.marshal())
c.handshakeComplete = true
c.cipherSuite = suiteId
c.cipherSuite = suite.id
return nil
}
......@@ -15,6 +15,7 @@ import (
"strconv"
"strings"
"testing"
"time"
)
type zeroSource struct{}
......@@ -31,7 +32,7 @@ var testConfig *Config
func init() {
testConfig = new(Config)
testConfig.Time = func() int64 { return 0 }
testConfig.Time = func() time.Time { return time.Unix(0, 0) }
testConfig.Rand = zeroSource{}
testConfig.Certificates = make([]Certificate, 1)
testConfig.Certificates[0].Certificate = [][]byte{testCertificate}
......
......@@ -90,13 +90,13 @@ func md5SHA1Hash(slices ...[]byte) []byte {
for _, slice := range slices {
hmd5.Write(slice)
}
copy(md5sha1, hmd5.Sum())
copy(md5sha1, hmd5.Sum(nil))
hsha1 := sha1.New()
for _, slice := range slices {
hsha1.Write(slice)
}
copy(md5sha1[md5.Size:], hsha1.Sum())
copy(md5sha1[md5.Size:], hsha1.Sum(nil))
return md5sha1
}
......
......@@ -22,14 +22,14 @@ func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum()
a := h.Sum(nil)
j := 0
for j < len(result) {
h.Reset()
h.Write(a)
h.Write(seed)
b := h.Sum()
b := h.Sum(nil)
todo := len(b)
if j+todo > len(result) {
todo = len(result) - j
......@@ -39,7 +39,7 @@ func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h.Reset()
h.Write(a)
a = h.Sum()
a = h.Sum(nil)
}
}
......@@ -84,13 +84,13 @@ func pRF30(result, secret, label, seed []byte) {
hashSHA1.Write(b[:i+1])
hashSHA1.Write(secret)
hashSHA1.Write(seed)
digest := hashSHA1.Sum()
digest := hashSHA1.Sum(nil)
hashMD5.Reset()
hashMD5.Write(secret)
hashMD5.Write(digest)
done += copy(result[done:], hashMD5.Sum())
done += copy(result[done:], hashMD5.Sum(nil))
i++
}
}
......@@ -182,24 +182,24 @@ func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic [4]byte) []by
md5.Write(magic[:])
md5.Write(masterSecret)
md5.Write(ssl30Pad1[:])
md5Digest := md5.Sum()
md5Digest := md5.Sum(nil)
md5.Reset()
md5.Write(masterSecret)
md5.Write(ssl30Pad2[:])
md5.Write(md5Digest)
md5Digest = md5.Sum()
md5Digest = md5.Sum(nil)
sha1.Write(magic[:])
sha1.Write(masterSecret)
sha1.Write(ssl30Pad1[:40])
sha1Digest := sha1.Sum()
sha1Digest := sha1.Sum(nil)
sha1.Reset()
sha1.Write(masterSecret)
sha1.Write(ssl30Pad2[:40])
sha1.Write(sha1Digest)
sha1Digest = sha1.Sum()
sha1Digest = sha1.Sum(nil)
ret := make([]byte, len(md5Digest)+len(sha1Digest))
copy(ret, md5Digest)
......@@ -217,8 +217,8 @@ func (h finishedHash) clientSum(masterSecret []byte) []byte {
return finishedSum30(h.clientMD5, h.clientSHA1, masterSecret, ssl3ClientFinishedMagic)
}
md5 := h.clientMD5.Sum()
sha1 := h.clientSHA1.Sum()
md5 := h.clientMD5.Sum(nil)
sha1 := h.clientSHA1.Sum(nil)
return finishedSum10(md5, sha1, clientFinishedLabel, masterSecret)
}
......@@ -229,7 +229,7 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte {
return finishedSum30(h.serverMD5, h.serverSHA1, masterSecret, ssl3ServerFinishedMagic)
}
md5 := h.serverMD5.Sum()
sha1 := h.serverSHA1.Sum()
md5 := h.serverMD5.Sum(nil)
sha1 := h.serverSHA1.Sum(nil)
return finishedSum10(md5, sha1, serverFinishedLabel, masterSecret)
}
......@@ -14,6 +14,7 @@ var certFiles = []string{
"/etc/ssl/certs/ca-certificates.crt", // Linux etc
"/etc/pki/tls/certs/ca-bundle.crt", // Fedora/RHEL
"/etc/ssl/ca-bundle.pem", // OpenSUSE
"/etc/ssl/cert.pem", // OpenBSD
}
func initDefaultRoots() {
......
......@@ -6,7 +6,6 @@ package tls
import (
"crypto/x509"
"reflect"
"syscall"
"unsafe"
)
......@@ -16,29 +15,23 @@ func loadStore(roots *x509.CertPool, name string) {
if err != nil {
return
}
defer syscall.CertCloseStore(store, 0)
var cert *syscall.CertContext
for {
cert = syscall.CertEnumCertificatesInStore(store, cert)
if cert == nil {
break
cert, err = syscall.CertEnumCertificatesInStore(store, cert)
if err != nil {
return
}
var asn1Slice []byte
hdrp := (*reflect.SliceHeader)(unsafe.Pointer(&asn1Slice))
hdrp.Data = cert.EncodedCert
hdrp.Len = int(cert.Length)
hdrp.Cap = int(cert.Length)
buf := make([]byte, len(asn1Slice))
copy(buf, asn1Slice)
if cert, err := x509.ParseCertificate(buf); err == nil {
roots.AddCert(cert)
buf := (*[1 << 20]byte)(unsafe.Pointer(cert.EncodedCert))[:]
// ParseCertificate requires its own copy of certificate data to keep.
buf2 := make([]byte, cert.Length)
copy(buf2, buf)
if c, err := x509.ParseCertificate(buf2); err == nil {
roots.AddCert(c)
}
}
syscall.CertCloseStore(store, 0)
}
func initDefaultRoots() {
......
......@@ -157,10 +157,21 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error)
return
}
key, err := x509.ParsePKCS1PrivateKey(keyDERBlock.Bytes)
if err != nil {
err = errors.New("crypto/tls: failed to parse key: " + err.Error())
return
// OpenSSL 0.9.8 generates PKCS#1 private keys by default, while
// OpenSSL 1.0.0 generates PKCS#8 keys. We try both.
var key *rsa.PrivateKey
if key, err = x509.ParsePKCS1PrivateKey(keyDERBlock.Bytes); err != nil {
var privKey interface{}
if privKey, err = x509.ParsePKCS8PrivateKey(keyDERBlock.Bytes); err != nil {
err = errors.New("crypto/tls: failed to parse key: " + err.Error())
return
}
var ok bool
if key, ok = privKey.(*rsa.PrivateKey); !ok {
err = errors.New("crypto/tls: found non-RSA private key in PKCS#8 wrapping")
return
}
}
cert.PrivateKey = key
......
......@@ -8,7 +8,7 @@ import (
"encoding/pem"
)
// Roots is a set of certificates.
// CertPool is a set of certificates.
type CertPool struct {
bySubjectKeyId map[string][]int
byName map[string][]int
......@@ -70,11 +70,11 @@ func (s *CertPool) AddCert(cert *Certificate) {
s.byName[name] = append(s.byName[name], n)
}
// AppendCertsFromPEM attempts to parse a series of PEM encoded root
// certificates. It appends any certificates found to s and returns true if any
// certificates were successfully parsed.
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
// It appends any certificates found to s and returns true if any certificates
// were successfully parsed.
//
// On many Linux systems, /etc/ssl/cert.pem will contains the system wide set
// On many Linux systems, /etc/ssl/cert.pem will contain the system wide set
// of root CAs in a format suitable for this function.
func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
for len(pemCerts) > 0 {
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
)
// pkcs8 reflects an ASN.1, PKCS#8 PrivateKey. See
// ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-8/pkcs-8v1_2.asn.
type pkcs8 struct {
Version int
Algo pkix.AlgorithmIdentifier
PrivateKey []byte
// optional attributes omitted.
}
// ParsePKCS8PrivateKey parses an unencrypted, PKCS#8 private key. See
// http://www.rsa.com/rsalabs/node.asp?id=2130
func ParsePKCS8PrivateKey(der []byte) (key interface{}, err error) {
var privKey pkcs8
if _, err := asn1.Unmarshal(der, &privKey); err != nil {
return nil, err
}
switch {
case privKey.Algo.Algorithm.Equal(oidRSA):
key, err = ParsePKCS1PrivateKey(privKey.PrivateKey)
if err != nil {
return nil, errors.New("crypto/x509: failed to parse RSA private key embedded in PKCS#8: " + err.Error())
}
return key, nil
default:
return nil, fmt.Errorf("crypto/x509: PKCS#8 wrapping contained private key with unknown algorithm: %v", privKey.Algo.Algorithm)
}
panic("unreachable")
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x509
import (
"encoding/hex"
"testing"
)
var pkcs8PrivateKeyHex = `30820278020100300d06092a864886f70d0101010500048202623082025e02010002818100cfb1b5bf9685ffa97b4f99df4ff122b70e59ac9b992f3bc2b3dde17d53c1a34928719b02e8fd17839499bfbd515bd6ef99c7a1c47a239718fe36bfd824c0d96060084b5f67f0273443007a24dfaf5634f7772c9346e10eb294c2306671a5a5e719ae24b4de467291bc571014b0e02dec04534d66a9bb171d644b66b091780e8d020301000102818100b595778383c4afdbab95d2bfed12b3f93bb0a73a7ad952f44d7185fd9ec6c34de8f03a48770f2009c8580bcd275e9632714e9a5e3f32f29dc55474b2329ff0ebc08b3ffcb35bc96e6516b483df80a4a59cceb71918cbabf91564e64a39d7e35dce21cb3031824fdbc845dba6458852ec16af5dddf51a8397a8797ae0337b1439024100ea0eb1b914158c70db39031dd8904d6f18f408c85fbbc592d7d20dee7986969efbda081fdf8bc40e1b1336d6b638110c836bfdc3f314560d2e49cd4fbde1e20b024100e32a4e793b574c9c4a94c8803db5152141e72d03de64e54ef2c8ed104988ca780cd11397bc359630d01b97ebd87067c5451ba777cf045ca23f5912f1031308c702406dfcdbbd5a57c9f85abc4edf9e9e29153507b07ce0a7ef6f52e60dcfebe1b8341babd8b789a837485da6c8d55b29bbb142ace3c24a1f5b54b454d01b51e2ad03024100bd6a2b60dee01e1b3bfcef6a2f09ed027c273cdbbaf6ba55a80f6dcc64e4509ee560f84b4f3e076bd03b11e42fe71a3fdd2dffe7e0902c8584f8cad877cdc945024100aa512fa4ada69881f1d8bb8ad6614f192b83200aef5edf4811313d5ef30a86cbd0a90f7b025c71ea06ec6b34db6306c86b1040670fd8654ad7291d066d06d031`
func TestPKCS8(t *testing.T) {
derBytes, _ := hex.DecodeString(pkcs8PrivateKeyHex)
_, err := ParsePKCS8PrivateKey(derBytes)
if err != nil {
t.Errorf("failed to decode PKCS8 key: %s", err)
}
}
......@@ -142,10 +142,9 @@ type CertificateList struct {
SignatureValue asn1.BitString
}
// HasExpired returns true iff currentTimeSeconds is past the expiry time of
// certList.
func (certList *CertificateList) HasExpired(currentTimeSeconds int64) bool {
return certList.TBSCertList.NextUpdate.Seconds() <= currentTimeSeconds
// HasExpired returns true iff now is past the expiry time of certList.
func (certList *CertificateList) HasExpired(now time.Time) bool {
return now.After(certList.TBSCertList.NextUpdate)
}
// TBSCertificateList represents the ASN.1 structure of the same name. See RFC
......@@ -155,8 +154,8 @@ type TBSCertificateList struct {
Version int `asn1:"optional,default:2"`
Signature AlgorithmIdentifier
Issuer RDNSequence
ThisUpdate *time.Time
NextUpdate *time.Time
ThisUpdate time.Time
NextUpdate time.Time
RevokedCertificates []RevokedCertificate `asn1:"optional"`
Extensions []Extension `asn1:"tag:0,optional,explicit"`
}
......@@ -165,6 +164,6 @@ type TBSCertificateList struct {
// 5280, section 5.1.
type RevokedCertificate struct {
SerialNumber *big.Int
RevocationTime *time.Time
RevocationTime time.Time
Extensions []Extension `asn1:"optional"`
}
......@@ -76,7 +76,7 @@ type VerifyOptions struct {
DNSName string
Intermediates *CertPool
Roots *CertPool
CurrentTime int64 // if 0, the current system time is used.
CurrentTime time.Time // if zero, the current time is used
}
const (
......@@ -87,8 +87,11 @@ const (
// isValid performs validity checks on the c.
func (c *Certificate) isValid(certType int, opts *VerifyOptions) error {
if opts.CurrentTime < c.NotBefore.Seconds() ||
opts.CurrentTime > c.NotAfter.Seconds() {
now := opts.CurrentTime
if now.IsZero() {
now = time.Now()
}
if now.Before(c.NotBefore) || now.After(c.NotAfter) {
return CertificateInvalidError{c, Expired}
}
......@@ -136,9 +139,6 @@ func (c *Certificate) isValid(certType int, opts *VerifyOptions) error {
//
// WARNING: this doesn't do any revocation checking.
func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err error) {
if opts.CurrentTime == 0 {
opts.CurrentTime = time.Seconds()
}
err = c.isValid(leafCertificate, &opts)
if err != nil {
return
......
......@@ -10,6 +10,7 @@ import (
"errors"
"strings"
"testing"
"time"
)
type verifyTest struct {
......@@ -133,7 +134,7 @@ func TestVerify(t *testing.T) {
Roots: NewCertPool(),
Intermediates: NewCertPool(),
DNSName: test.dnsName,
CurrentTime: test.currentTime,
CurrentTime: time.Unix(test.currentTime, 0),
}
for j, root := range test.roots {
......
......@@ -107,7 +107,7 @@ type dsaSignature struct {
}
type validity struct {
NotBefore, NotAfter *time.Time
NotBefore, NotAfter time.Time
}
type publicKeyInfo struct {
......@@ -303,7 +303,7 @@ type Certificate struct {
SerialNumber *big.Int
Issuer pkix.Name
Subject pkix.Name
NotBefore, NotAfter *time.Time // Validity bounds.
NotBefore, NotAfter time.Time // Validity bounds.
KeyUsage KeyUsage
ExtKeyUsage []ExtKeyUsage // Sequence of extended key usages.
......@@ -398,7 +398,7 @@ func (c *Certificate) CheckSignature(algo SignatureAlgorithm, signed, signature
}
h.Write(signed)
digest := h.Sum()
digest := h.Sum(nil)
switch pub := c.PublicKey.(type) {
case *rsa.PublicKey:
......@@ -899,11 +899,10 @@ var (
oidRSA = []int{1, 2, 840, 113549, 1, 1, 1}
)
// CreateSelfSignedCertificate creates a new certificate based on
// a template. The following members of template are used: SerialNumber,
// Subject, NotBefore, NotAfter, KeyUsage, BasicConstraintsValid, IsCA,
// MaxPathLen, SubjectKeyId, DNSNames, PermittedDNSDomainsCritical,
// PermittedDNSDomains.
// CreateCertificate creates a new certificate based on a template. The
// following members of template are used: SerialNumber, Subject, NotBefore,
// NotAfter, KeyUsage, BasicConstraintsValid, IsCA, MaxPathLen, SubjectKeyId,
// DNSNames, PermittedDNSDomainsCritical, PermittedDNSDomains.
//
// The certificate is signed by parent. If parent is equal to template then the
// certificate is self-signed. The parameter pub is the public key of the
......@@ -958,7 +957,7 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P
h := sha1.New()
h.Write(tbsCertContents)
digest := h.Sum()
digest := h.Sum(nil)
signature, err := rsa.SignPKCS1v15(rand, priv, crypto.SHA1, digest)
if err != nil {
......@@ -1006,7 +1005,7 @@ func ParseDERCRL(derBytes []byte) (certList *pkix.CertificateList, err error) {
// CreateCRL returns a DER encoded CRL, signed by this Certificate, that
// contains the given list of revoked certificates.
func (c *Certificate) CreateCRL(rand io.Reader, priv *rsa.PrivateKey, revokedCerts []pkix.RevokedCertificate, now, expiry *time.Time) (crlBytes []byte, err error) {
func (c *Certificate) CreateCRL(rand io.Reader, priv *rsa.PrivateKey, revokedCerts []pkix.RevokedCertificate, now, expiry time.Time) (crlBytes []byte, err error) {
tbsCertList := pkix.TBSCertificateList{
Version: 2,
Signature: pkix.AlgorithmIdentifier{
......@@ -1025,7 +1024,7 @@ func (c *Certificate) CreateCRL(rand io.Reader, priv *rsa.PrivateKey, revokedCer
h := sha1.New()
h.Write(tbsCertListContents)
digest := h.Sum()
digest := h.Sum(nil)
signature, err := rsa.SignPKCS1v15(rand, priv, crypto.SHA1, digest)
if err != nil {
......
......@@ -250,8 +250,8 @@ func TestCreateSelfSignedCertificate(t *testing.T) {
CommonName: commonName,
Organization: []string{"Acme Co"},
},
NotBefore: time.SecondsToUTC(1000),
NotAfter: time.SecondsToUTC(100000),
NotBefore: time.Unix(1000, 0),
NotAfter: time.Unix(100000, 0),
SubjectKeyId: []byte{1, 2, 3, 4},
KeyUsage: KeyUsageCertSign,
......@@ -396,8 +396,8 @@ func TestCRLCreation(t *testing.T) {
block, _ = pem.Decode([]byte(pemCertificate))
cert, _ := ParseCertificate(block.Bytes)
now := time.SecondsToUTC(1000)
expiry := time.SecondsToUTC(10000)
now := time.Unix(1000, 0)
expiry := time.Unix(10000, 0)
revokedCerts := []pkix.RevokedCertificate{
{
......@@ -443,7 +443,7 @@ func TestParseDERCRL(t *testing.T) {
t.Errorf("bad number of revoked certificates. got: %d want: %d", numCerts, expected)
}
if certList.HasExpired(1302517272) {
if certList.HasExpired(time.Unix(1302517272, 0)) {
t.Errorf("CRL has expired (but shouldn't have)")
}
......@@ -463,7 +463,7 @@ func TestParsePEMCRL(t *testing.T) {
t.Errorf("bad number of revoked certificates. got: %d want: %d", numCerts, expected)
}
if certList.HasExpired(1302517272) {
if certList.HasExpired(time.Unix(1302517272, 0)) {
t.Errorf("CRL has expired (but shouldn't have)")
}
......
......@@ -247,7 +247,7 @@ func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, err error)
// UTCTime
func parseUTCTime(bytes []byte) (ret *time.Time, err error) {
func parseUTCTime(bytes []byte) (ret time.Time, err error) {
s := string(bytes)
ret, err = time.Parse("0601021504Z0700", s)
if err == nil {
......@@ -259,7 +259,7 @@ func parseUTCTime(bytes []byte) (ret *time.Time, err error) {
// parseGeneralizedTime parses the GeneralizedTime from the given byte slice
// and returns the resulting time.
func parseGeneralizedTime(bytes []byte) (ret *time.Time, err error) {
func parseGeneralizedTime(bytes []byte) (ret time.Time, err error) {
return time.Parse("20060102150405Z0700", string(bytes))
}
......@@ -450,7 +450,7 @@ var (
objectIdentifierType = reflect.TypeOf(ObjectIdentifier{})
enumeratedType = reflect.TypeOf(Enumerated(0))
flagType = reflect.TypeOf(Flag(false))
timeType = reflect.TypeOf(&time.Time{})
timeType = reflect.TypeOf(time.Time{})
rawValueType = reflect.TypeOf(RawValue{})
rawContentsType = reflect.TypeOf(RawContent(nil))
bigIntType = reflect.TypeOf(new(big.Int))
......@@ -647,7 +647,7 @@ func parseField(v reflect.Value, bytes []byte, initOffset int, params fieldParam
err = err1
return
case timeType:
var time *time.Time
var time time.Time
var err1 error
if universalTag == tagUTCTime {
time, err1 = parseUTCTime(innerBytes)
......@@ -799,7 +799,7 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
//
// An ASN.1 ENUMERATED can be written to an Enumerated.
//
// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a *time.Time.
// An ASN.1 UTCTIME or GENERALIZEDTIME can be written to a time.Time.
//
// An ASN.1 PrintableString or IA5String can be written to a string.
//
......
......@@ -202,43 +202,51 @@ func TestObjectIdentifier(t *testing.T) {
type timeTest struct {
in string
ok bool
out *time.Time
out time.Time
}
var utcTestData = []timeTest{
{"910506164540-0700", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, -7 * 60 * 60, ""}},
{"910506164540+0730", true, &time.Time{1991, 05, 06, 16, 45, 40, 0, 7*60*60 + 30*60, ""}},
{"910506234540Z", true, &time.Time{1991, 05, 06, 23, 45, 40, 0, 0, "UTC"}},
{"9105062345Z", true, &time.Time{1991, 05, 06, 23, 45, 0, 0, 0, "UTC"}},
{"a10506234540Z", false, nil},
{"91a506234540Z", false, nil},
{"9105a6234540Z", false, nil},
{"910506a34540Z", false, nil},
{"910506334a40Z", false, nil},
{"91050633444aZ", false, nil},
{"910506334461Z", false, nil},
{"910506334400Za", false, nil},
{"910506164540-0700", true, time.Date(1991, 05, 06, 16, 45, 40, 0, time.FixedZone("", -7*60*60))},
{"910506164540+0730", true, time.Date(1991, 05, 06, 16, 45, 40, 0, time.FixedZone("", 7*60*60+30*60))},
{"910506234540Z", true, time.Date(1991, 05, 06, 23, 45, 40, 0, time.UTC)},
{"9105062345Z", true, time.Date(1991, 05, 06, 23, 45, 0, 0, time.UTC)},
{"a10506234540Z", false, time.Time{}},
{"91a506234540Z", false, time.Time{}},
{"9105a6234540Z", false, time.Time{}},
{"910506a34540Z", false, time.Time{}},
{"910506334a40Z", false, time.Time{}},
{"91050633444aZ", false, time.Time{}},
{"910506334461Z", false, time.Time{}},
{"910506334400Za", false, time.Time{}},
}
func TestUTCTime(t *testing.T) {
for i, test := range utcTestData {
ret, err := parseUTCTime([]byte(test.in))
if (err == nil) != test.ok {
t.Errorf("#%d: Incorrect error result (did fail? %v, expected: %v)", i, err == nil, test.ok)
}
if err == nil {
if !reflect.DeepEqual(test.out, ret) {
t.Errorf("#%d: Bad result: %v (expected %v)", i, ret, test.out)
if err != nil {
if test.ok {
t.Errorf("#%d: parseUTCTime(%q) = error %v", i, err)
}
continue
}
if !test.ok {
t.Errorf("#%d: parseUTCTime(%q) succeeded, should have failed", i)
continue
}
const format = "Jan _2 15:04:05 -0700 2006" // ignore zone name, just offset
have := ret.Format(format)
want := test.out.Format(format)
if have != want {
t.Errorf("#%d: parseUTCTime(%q) = %s, want %s", test.in, have, want)
}
}
}
var generalizedTimeTestData = []timeTest{
{"20100102030405Z", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 0, "UTC"}},
{"20100102030405", false, nil},
{"20100102030405+0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, 6*60*60 + 7*60, ""}},
{"20100102030405-0607", true, &time.Time{2010, 01, 02, 03, 04, 05, 0, -6*60*60 - 7*60, ""}},
{"20100102030405Z", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.UTC)},
{"20100102030405", false, time.Time{}},
{"20100102030405+0607", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.FixedZone("", 6*60*60+7*60))},
{"20100102030405-0607", true, time.Date(2010, 01, 02, 03, 04, 05, 0, time.FixedZone("", -6*60*60-7*60))},
}
func TestGeneralizedTime(t *testing.T) {
......@@ -407,7 +415,7 @@ type AttributeTypeAndValue struct {
}
type Validity struct {
NotBefore, NotAfter *time.Time
NotBefore, NotAfter time.Time
}
type PublicKeyInfo struct {
......@@ -475,7 +483,10 @@ var derEncodedSelfSignedCert = Certificate{
RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 3}, Value: "false.example.com"}},
RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{1, 2, 840, 113549, 1, 9, 1}, Value: "false@example.com"}},
},
Validity: Validity{NotBefore: &time.Time{Year: 2009, Month: 10, Day: 8, Hour: 0, Minute: 25, Second: 53, ZoneOffset: 0, Zone: "UTC"}, NotAfter: &time.Time{Year: 2010, Month: 10, Day: 8, Hour: 0, Minute: 25, Second: 53, ZoneOffset: 0, Zone: "UTC"}},
Validity: Validity{
NotBefore: time.Date(2009, 10, 8, 00, 25, 53, 0, time.UTC),
NotAfter: time.Date(2010, 10, 8, 00, 25, 53, 0, time.UTC),
},
Subject: RDNSequence{
RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 6}, Value: "XX"}},
RelativeDistinguishedNameSET{AttributeTypeAndValue{Type: ObjectIdentifier{2, 5, 4, 8}, Value: "Some-State"}},
......
......@@ -288,52 +288,58 @@ func marshalTwoDigits(out *forkableWriter, v int) (err error) {
return out.WriteByte(byte('0' + v%10))
}
func marshalUTCTime(out *forkableWriter, t *time.Time) (err error) {
func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
utc := t.UTC()
year, month, day := utc.Date()
switch {
case 1950 <= t.Year && t.Year < 2000:
err = marshalTwoDigits(out, int(t.Year-1900))
case 2000 <= t.Year && t.Year < 2050:
err = marshalTwoDigits(out, int(t.Year-2000))
case 1950 <= year && year < 2000:
err = marshalTwoDigits(out, int(year-1900))
case 2000 <= year && year < 2050:
err = marshalTwoDigits(out, int(year-2000))
default:
return StructuralError{"Cannot represent time as UTCTime"}
}
if err != nil {
return
}
err = marshalTwoDigits(out, t.Month)
err = marshalTwoDigits(out, int(month))
if err != nil {
return
}
err = marshalTwoDigits(out, t.Day)
err = marshalTwoDigits(out, day)
if err != nil {
return
}
err = marshalTwoDigits(out, t.Hour)
hour, min, sec := utc.Clock()
err = marshalTwoDigits(out, hour)
if err != nil {
return
}
err = marshalTwoDigits(out, t.Minute)
err = marshalTwoDigits(out, min)
if err != nil {
return
}
err = marshalTwoDigits(out, t.Second)
err = marshalTwoDigits(out, sec)
if err != nil {
return
}
_, offset := t.Zone()
switch {
case t.ZoneOffset/60 == 0:
case offset/60 == 0:
err = out.WriteByte('Z')
return
case t.ZoneOffset > 0:
case offset > 0:
err = out.WriteByte('+')
case t.ZoneOffset < 0:
case offset < 0:
err = out.WriteByte('-')
}
......@@ -341,7 +347,7 @@ func marshalUTCTime(out *forkableWriter, t *time.Time) (err error) {
return
}
offsetMinutes := t.ZoneOffset / 60
offsetMinutes := offset / 60
if offsetMinutes < 0 {
offsetMinutes = -offsetMinutes
}
......@@ -366,7 +372,7 @@ func stripTagAndLength(in []byte) []byte {
func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
switch value.Type() {
case timeType:
return marshalUTCTime(out, value.Interface().(*time.Time))
return marshalUTCTime(out, value.Interface().(time.Time))
case bitStringType:
return marshalBitString(out, value.Interface().(BitString))
case objectIdentifierType:
......
......@@ -51,10 +51,7 @@ type optionalRawValueTest struct {
type testSET []int
func setPST(t *time.Time) *time.Time {
t.ZoneOffset = -28800
return t
}
var PST = time.FixedZone("PST", -8*60*60)
type marshalTest struct {
in interface{}
......@@ -73,9 +70,9 @@ var marshalTests = []marshalTest{
{[]byte{1, 2, 3}, "0403010203"},
{implicitTagTest{64}, "3003850140"},
{explicitTagTest{64}, "3005a503020140"},
{time.SecondsToUTC(0), "170d3730303130313030303030305a"},
{time.SecondsToUTC(1258325776), "170d3039313131353232353631365a"},
{setPST(time.SecondsToUTC(1258325776)), "17113039313131353232353631362d30383030"},
{time.Unix(0, 0).UTC(), "170d3730303130313030303030305a"},
{time.Unix(1258325776, 0).UTC(), "170d3039313131353232353631365a"},
{time.Unix(1258325776, 0).In(PST), "17113039313131353232353631362d30383030"},
{BitString{[]byte{0x80}, 1}, "03020780"},
{BitString{[]byte{0x81, 0xf0}, 12}, "03030481f0"},
{ObjectIdentifier([]int{1, 2, 3, 4}), "06032a0304"},
......@@ -123,7 +120,8 @@ func TestMarshal(t *testing.T) {
}
out, _ := hex.DecodeString(test.out)
if bytes.Compare(out, data) != 0 {
t.Errorf("#%d got: %x want %x", i, data, out)
t.Errorf("#%d got: %x want %x\n\t%q\n\t%q", i, data, out, data, out)
}
}
}
......@@ -16,6 +16,7 @@ import (
"runtime"
"sort"
"strconv"
"sync"
"unicode"
"unicode/utf8"
)
......@@ -295,28 +296,10 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
case reflect.Struct:
e.WriteByte('{')
t := v.Type()
n := v.NumField()
first := true
for i := 0; i < n; i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue
}
tag, omitEmpty, quoted := f.Name, false, false
if tv := f.Tag.Get("json"); tv != "" {
if tv == "-" {
continue
}
name, opts := parseTag(tv)
if isValidTag(name) {
tag = name
}
omitEmpty = opts.Contains("omitempty")
quoted = opts.Contains("string")
}
fieldValue := v.Field(i)
if omitEmpty && isEmptyValue(fieldValue) {
for _, ef := range encodeFields(v.Type()) {
fieldValue := v.Field(ef.i)
if ef.omitEmpty && isEmptyValue(fieldValue) {
continue
}
if first {
......@@ -324,9 +307,9 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
} else {
e.WriteByte(',')
}
e.string(tag)
e.string(ef.tag)
e.WriteByte(':')
e.reflectValueQuoted(fieldValue, quoted)
e.reflectValueQuoted(fieldValue, ef.quoted)
}
e.WriteByte('}')
......@@ -470,3 +453,63 @@ func (e *encodeState) string(s string) (int, error) {
e.WriteByte('"')
return e.Len() - len0, nil
}
// encodeField contains information about how to encode a field of a
// struct.
type encodeField struct {
i int // field index in struct
tag string
quoted bool
omitEmpty bool
}
var (
typeCacheLock sync.RWMutex
encodeFieldsCache = make(map[reflect.Type][]encodeField)
)
// encodeFields returns a slice of encodeField for a given
// struct type.
func encodeFields(t reflect.Type) []encodeField {
typeCacheLock.RLock()
fs, ok := encodeFieldsCache[t]
typeCacheLock.RUnlock()
if ok {
return fs
}
typeCacheLock.Lock()
defer typeCacheLock.Unlock()
fs, ok = encodeFieldsCache[t]
if ok {
return fs
}
v := reflect.Zero(t)
n := v.NumField()
for i := 0; i < n; i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue
}
var ef encodeField
ef.i = i
ef.tag = f.Name
tv := f.Tag.Get("json")
if tv != "" {
if tv == "-" {
continue
}
name, opts := parseTag(tv)
if isValidTag(name) {
ef.tag = name
}
ef.omitEmpty = opts.Contains("omitempty")
ef.quoted = opts.Contains("string")
}
fs = append(fs, ef)
}
encodeFieldsCache[t] = fs
return fs
}
......@@ -61,7 +61,7 @@ type StartElement struct {
func (e StartElement) Copy() StartElement {
attrs := make([]Attr, len(e.Attr))
copy(e.Attr, attrs)
copy(attrs, e.Attr)
e.Attr = attrs
return e
}
......
......@@ -29,71 +29,69 @@ const testInput = `
</body><!-- missing final newline -->`
var rawTokens = []Token{
CharData([]byte("\n")),
CharData("\n"),
ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)},
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
CharData("\n"),
Directive(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"`),
),
CharData([]byte("\n")),
CharData("\n"),
StartElement{Name{"", "body"}, []Attr{{Name{"xmlns", "foo"}, "ns1"}, {Name{"", "xmlns"}, "ns2"}, {Name{"xmlns", "tag"}, "ns3"}}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"", "hello"}, []Attr{{Name{"", "lang"}, "en"}}},
CharData([]byte("World <>'\" 白鵬翔")),
CharData("World <>'\" 白鵬翔"),
EndElement{Name{"", "hello"}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"", "goodbye"}, []Attr{}},
EndElement{Name{"", "goodbye"}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"", "inner"}, []Attr{}},
EndElement{Name{"", "inner"}},
CharData([]byte("\n ")),
CharData("\n "),
EndElement{Name{"", "outer"}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"tag", "name"}, []Attr{}},
CharData([]byte("\n ")),
CharData([]byte("Some text here.")),
CharData([]byte("\n ")),
CharData("\n "),
CharData("Some text here."),
CharData("\n "),
EndElement{Name{"tag", "name"}},
CharData([]byte("\n")),
CharData("\n"),
EndElement{Name{"", "body"}},
Comment([]byte(" missing final newline ")),
Comment(" missing final newline "),
}
var cookedTokens = []Token{
CharData([]byte("\n")),
CharData("\n"),
ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)},
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
CharData("\n"),
Directive(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"`),
),
CharData([]byte("\n")),
CharData("\n"),
StartElement{Name{"ns2", "body"}, []Attr{{Name{"xmlns", "foo"}, "ns1"}, {Name{"", "xmlns"}, "ns2"}, {Name{"xmlns", "tag"}, "ns3"}}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"ns2", "hello"}, []Attr{{Name{"", "lang"}, "en"}}},
CharData([]byte("World <>'\" 白鵬翔")),
CharData("World <>'\" 白鵬翔"),
EndElement{Name{"ns2", "hello"}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"ns2", "goodbye"}, []Attr{}},
EndElement{Name{"ns2", "goodbye"}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"ns2", "inner"}, []Attr{}},
EndElement{Name{"ns2", "inner"}},
CharData([]byte("\n ")),
CharData("\n "),
EndElement{Name{"ns2", "outer"}},
CharData([]byte("\n ")),
CharData("\n "),
StartElement{Name{"ns3", "name"}, []Attr{}},
CharData([]byte("\n ")),
CharData([]byte("Some text here.")),
CharData([]byte("\n ")),
CharData("\n "),
CharData("Some text here."),
CharData("\n "),
EndElement{Name{"ns3", "name"}},
CharData([]byte("\n")),
CharData("\n"),
EndElement{Name{"ns2", "body"}},
Comment([]byte(" missing final newline ")),
Comment(" missing final newline "),
}
const testInputAltEncoding = `
......@@ -101,11 +99,11 @@ const testInputAltEncoding = `
<TAG>VALUE</TAG>`
var rawTokensAltEncoding = []Token{
CharData([]byte("\n")),
CharData("\n"),
ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)},
CharData([]byte("\n")),
CharData("\n"),
StartElement{Name{"", "tag"}, []Attr{}},
CharData([]byte("value")),
CharData("value"),
EndElement{Name{"", "tag"}},
}
......@@ -270,21 +268,21 @@ var nestedDirectivesInput = `
`
var nestedDirectivesTokens = []Token{
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE [<!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]`)),
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE [<!ENTITY xlt ">">]`)),
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE [<!ENTITY xlt "<">]`)),
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE [<!ENTITY xlt '>'>]`)),
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE [<!ENTITY xlt '<'>]`)),
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE [<!ENTITY xlt '">'>]`)),
CharData([]byte("\n")),
Directive([]byte(`DOCTYPE [<!ENTITY xlt "'<">]`)),
CharData([]byte("\n")),
CharData("\n"),
Directive(`DOCTYPE [<!ENTITY rdf "http://www.w3.org/1999/02/22-rdf-syntax-ns#">]`),
CharData("\n"),
Directive(`DOCTYPE [<!ENTITY xlt ">">]`),
CharData("\n"),
Directive(`DOCTYPE [<!ENTITY xlt "<">]`),
CharData("\n"),
Directive(`DOCTYPE [<!ENTITY xlt '>'>]`),
CharData("\n"),
Directive(`DOCTYPE [<!ENTITY xlt '<'>]`),
CharData("\n"),
Directive(`DOCTYPE [<!ENTITY xlt '">'>]`),
CharData("\n"),
Directive(`DOCTYPE [<!ENTITY xlt "'<">]`),
CharData("\n"),
}
func TestNestedDirectives(t *testing.T) {
......@@ -488,10 +486,13 @@ func TestCopyTokenStartElement(t *testing.T) {
elt := StartElement{Name{"", "hello"}, []Attr{{Name{"", "lang"}, "en"}}}
var tok1 Token = elt
tok2 := CopyToken(tok1)
if tok1.(StartElement).Attr[0].Value != "en" {
t.Error("CopyToken overwrote Attr[0]")
}
if !reflect.DeepEqual(tok1, tok2) {
t.Error("CopyToken(StartElement) != StartElement")
}
elt.Attr[0] = Attr{Name{"", "lang"}, "de"}
tok1.(StartElement).Attr[0] = Attr{Name{"", "lang"}, "de"}
if reflect.DeepEqual(tok1, tok2) {
t.Error("CopyToken(CharData) uses same buffer.")
}
......
......@@ -150,15 +150,15 @@ func processFiles(filenames []string, allFiles bool) {
switch info, err := os.Stat(filename); {
case err != nil:
report(err)
case info.IsRegular():
if allFiles || isGoFilename(info.Name) {
filenames[i] = filename
i++
}
case info.IsDirectory():
case info.IsDir():
if allFiles || *recursive {
processDirectory(filename)
}
default:
if allFiles || isGoFilename(info.Name()) {
filenames[i] = filename
i++
}
}
}
fset := token.NewFileSet()
......
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gui defines a basic graphical user interface programming model.
package gui
import (
"image"
"image/draw"
)
// A Window represents a single graphics window.
type Window interface {
// Screen returns an editable Image for the window.
Screen() draw.Image
// FlushImage flushes changes made to Screen() back to screen.
FlushImage()
// EventChan returns a channel carrying UI events such as key presses,
// mouse movements and window resizes.
EventChan() <-chan interface{}
// Close closes the window.
Close() error
}
// A KeyEvent is sent for a key press or release.
type KeyEvent struct {
// The value k represents key k being pressed.
// The value -k represents key k being released.
// The specific set of key values is not specified,
// but ordinary characters represent themselves.
Key int
}
// A MouseEvent is sent for a button press or release or for a mouse movement.
type MouseEvent struct {
// Buttons is a bit mask of buttons: 1<<0 is left, 1<<1 middle, 1<<2 right.
// It represents button state and not necessarily the state delta: bit 0
// being on means that the left mouse button is down, but does not imply
// that the same button was up in the previous MouseEvent.
Buttons int
// Loc is the location of the cursor.
Loc image.Point
// Nsec is the event's timestamp.
Nsec int64
}
// A ConfigEvent is sent each time the window's color model or size changes.
// The client should respond by calling Window.Screen to obtain a new image.
type ConfigEvent struct {
Config image.Config
}
// An ErrEvent is sent when an error occurs.
type ErrEvent struct {
Err error
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package x11
import (
"bufio"
"errors"
"io"
"os"
)
// readU16BE reads a big-endian uint16 from r, using b as a scratch buffer.
func readU16BE(r io.Reader, b []byte) (uint16, error) {
_, err := io.ReadFull(r, b[0:2])
if err != nil {
return 0, err
}
return uint16(b[0])<<8 + uint16(b[1]), nil
}
// readStr reads a length-prefixed string from r, using b as a scratch buffer.
func readStr(r io.Reader, b []byte) (string, error) {
n, err := readU16BE(r, b)
if err != nil {
return "", err
}
if int(n) > len(b) {
return "", errors.New("Xauthority entry too long for buffer")
}
_, err = io.ReadFull(r, b[0:n])
if err != nil {
return "", err
}
return string(b[0:n]), nil
}
// readAuth reads the X authority file and returns the name/data pair for the display.
// displayStr is the "12" out of a $DISPLAY like ":12.0".
func readAuth(displayStr string) (name, data string, err error) {
// b is a scratch buffer to use and should be at least 256 bytes long
// (i.e. it should be able to hold a hostname).
var b [256]byte
// As per /usr/include/X11/Xauth.h.
const familyLocal = 256
fn := os.Getenv("XAUTHORITY")
if fn == "" {
home := os.Getenv("HOME")
if home == "" {
err = errors.New("Xauthority not found: $XAUTHORITY, $HOME not set")
return
}
fn = home + "/.Xauthority"
}
r, err := os.Open(fn)
if err != nil {
return
}
defer r.Close()
br := bufio.NewReader(r)
hostname, err := os.Hostname()
if err != nil {
return
}
for {
var family uint16
var addr, disp, name0, data0 string
family, err = readU16BE(br, b[0:2])
if err != nil {
return
}
addr, err = readStr(br, b[0:])
if err != nil {
return
}
disp, err = readStr(br, b[0:])
if err != nil {
return
}
name0, err = readStr(br, b[0:])
if err != nil {
return
}
data0, err = readStr(br, b[0:])
if err != nil {
return
}
if family == familyLocal && addr == hostname && disp == displayStr {
return name0, data0, nil
}
}
panic("unreachable")
}
......@@ -7,7 +7,7 @@
//
// Code simply using databases should use package sql.
//
// Drivers only need to be aware of a subset of Go's types. The db package
// Drivers only need to be aware of a subset of Go's types. The sql package
// will convert all types into one of the following:
//
// int64
......@@ -94,12 +94,35 @@ type Result interface {
// used by multiple goroutines concurrently.
type Stmt interface {
// Close closes the statement.
//
// Closing a statement should not interrupt any outstanding
// query created from that statement. That is, the following
// order of operations is valid:
//
// * create a driver statement
// * call Query on statement, returning Rows
// * close the statement
// * read from Rows
//
// If closing a statement invalidates currently-running
// queries, the final step above will incorrectly fail.
//
// TODO(bradfitz): possibly remove the restriction above, if
// enough driver authors object and find it complicates their
// code too much. The sql package could be smarter about
// refcounting the statement and closing it at the appropriate
// time.
Close() error
// NumInput returns the number of placeholder parameters.
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
//
// If NumInput returns >= 0, the sql package will sanity check
// argument counts from callers and return errors to the caller
// before the statement's Exec or Query methods are called.
//
// NumInput may also return -1, if the driver doesn't know
// its number of placeholders. In that case, the sql package
// will not sanity check Exec or Query argument counts.
NumInput() int
// Exec executes a query that doesn't return rows, such
......
......@@ -90,6 +90,8 @@ type fakeStmt struct {
cmd string
table string
closed bool
colName []string // used by CREATE, INSERT, SELECT (selected columns)
colType []string // used by CREATE
colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
......@@ -232,6 +234,9 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
stmt.table = parts[0]
stmt.colName = strings.Split(parts[1], ",")
for n, colspec := range strings.Split(parts[2], ",") {
if colspec == "" {
continue
}
nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 {
return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
......@@ -342,10 +347,16 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
}
func (s *fakeStmt) Close() error {
s.closed = true
return nil
}
var errClosed = errors.New("fakedb: statement has been closed")
func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
if s.closed {
return nil, errClosed
}
err := checkSubsetTypes(args)
if err != nil {
return nil, err
......@@ -405,6 +416,9 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
}
func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
if s.closed {
return nil, errClosed
}
err := checkSubsetTypes(args)
if err != nil {
return nil, err
......
......@@ -344,25 +344,26 @@ func (tx *Tx) Rollback() error {
return tx.txi.Rollback()
}
// Prepare creates a prepared statement.
// Prepare creates a prepared statement for use within a transaction.
//
// The statement is only valid within the scope of this transaction.
// The returned statement operates within the transaction and can no longer
// be used once the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) {
// TODO(bradfitz): the restriction that the returned statement
// is only valid for this Transaction is lame and negates a
// lot of the benefit of prepared statements. We could be
// more efficient here and either provide a method to take an
// existing Stmt (created on perhaps a different Conn), and
// re-create it on this Conn if necessary. Or, better: keep a
// map in DB of query string to Stmts, and have Stmt.Execute
// do the right thing and re-prepare if the Conn in use
// doesn't have that prepared statement. But we'll want to
// avoid caching the statement in the case where we only call
// conn.Prepare implicitly (such as in db.Exec or tx.Exec),
// but the caller package can't be holding a reference to the
// returned statement. Perhaps just looking at the reference
// count (by noting Stmt.Close) would be enough. We might also
// want a finalizer on Stmt to drop the reference count.
// TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if
// necessary. Or, better: keep a map in DB of query string to
// Stmts, and have Stmt.Execute do the right thing and
// re-prepare if the Conn in use doesn't have that prepared
// statement. But we'll want to avoid caching the statement
// in the case where we only call conn.Prepare implicitly
// (such as in db.Exec or tx.Exec), but the caller package
// can't be holding a reference to the returned statement.
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
ci, err := tx.grabConn()
if err != nil {
return nil, err
......@@ -383,6 +384,39 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
return stmt, nil
}
// Stmt returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
// per-Conn. See also the big comment in Tx.Prepare.
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
ci, err := tx.grabConn()
if err != nil {
return &Stmt{stickyErr: err}
}
defer tx.releaseConn()
si, err := ci.Prepare(stmt.query)
return &Stmt{
db: tx.db,
tx: tx,
txsi: si,
query: stmt.query,
stickyErr: err,
}
}
// Exec executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
......@@ -448,8 +482,9 @@ type connStmt struct {
// Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines.
type Stmt struct {
// Immutable:
db *DB // where we came from
query string // that created the Sttm
db *DB // where we came from
query string // that created the Stmt
stickyErr error // if non-nil, this error is returned for all operations
// If in a transaction, else both nil:
tx *Tx
......@@ -513,6 +548,9 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
if s.stickyErr != nil {
return nil, nil, nil, s.stickyErr
}
s.mu.Lock()
if s.closed {
s.mu.Unlock()
......@@ -621,6 +659,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
// Close closes the statement.
func (s *Stmt) Close() error {
if s.stickyErr != nil {
return s.stickyErr
}
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
......
......@@ -5,6 +5,7 @@
package sql
import (
"reflect"
"strings"
"testing"
)
......@@ -22,7 +23,6 @@ func newTestDB(t *testing.T, name string) *DB {
exec(t, db, "INSERT|people|name=Alice,age=?", 1)
exec(t, db, "INSERT|people|name=Bob,age=?", 2)
exec(t, db, "INSERT|people|name=Chris,age=?", 3)
}
return db
}
......@@ -44,6 +44,40 @@ func closeDB(t *testing.T, db *DB) {
func TestQuery(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
rows, err := db.Query("SELECT|people|age,name|")
if err != nil {
t.Fatalf("Query: %v", err)
}
type row struct {
age int
name string
}
got := []row{}
for rows.Next() {
var r row
err = rows.Scan(&r.age, &r.name)
if err != nil {
t.Fatalf("Scan: %v", err)
}
got = append(got, r)
}
err = rows.Err()
if err != nil {
t.Fatalf("Err: %v", err)
}
want := []row{
{age: 1, name: "Alice"},
{age: 2, name: "Bob"},
{age: 3, name: "Chris"},
}
if !reflect.DeepEqual(got, want) {
t.Logf(" got: %#v\nwant: %#v", got, want)
}
}
func TestQueryRow(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
var name string
var age int
......@@ -75,6 +109,24 @@ func TestQuery(t *testing.T) {
}
}
func TestStatementErrorAfterClose(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
stmt, err := db.Prepare("SELECT|people|age|name=?")
if err != nil {
t.Fatalf("Prepare: %v", err)
}
err = stmt.Close()
if err != nil {
t.Fatalf("Close: %v", err)
}
var name string
err = stmt.QueryRow("foo").Scan(&name)
if err == nil {
t.Errorf("expected error from QueryRow.Scan after Stmt.Close")
}
}
func TestStatementQueryRow(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
......@@ -114,7 +166,7 @@ func TestBogusPreboundParameters(t *testing.T) {
}
}
func TestDb(t *testing.T) {
func TestExec(t *testing.T) {
db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
......@@ -154,3 +206,25 @@ func TestDb(t *testing.T) {
}
}
}
func TestTxStmt(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
_, err = tx.Stmt(stmt).Exec("Bobby", 7)
if err != nil {
t.Fatalf("Exec = %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit = %v", err)
}
}
......@@ -244,13 +244,13 @@ func (c *channel) Write(data []byte) (n int, err error) {
packet := make([]byte, 1+4+4+len(todo))
packet[0] = msgChannelData
packet[1] = byte(c.theirId) >> 24
packet[2] = byte(c.theirId) >> 16
packet[3] = byte(c.theirId) >> 8
packet[1] = byte(c.theirId >> 24)
packet[2] = byte(c.theirId >> 16)
packet[3] = byte(c.theirId >> 8)
packet[4] = byte(c.theirId)
packet[5] = byte(len(todo)) >> 24
packet[6] = byte(len(todo)) >> 16
packet[7] = byte(len(todo)) >> 8
packet[5] = byte(len(todo) >> 24)
packet[6] = byte(len(todo) >> 16)
packet[7] = byte(len(todo) >> 8)
packet[8] = byte(len(todo))
copy(packet[9:], todo)
......
......@@ -172,40 +172,12 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
marshalInt(K, kInt)
h.Write(K)
H := h.Sum()
H := h.Sum(nil)
return H, K, nil
}
// openChan opens a new client channel. The most common session type is "session".
// The full set of valid session types are listed in RFC 4250 4.9.1.
func (c *ClientConn) openChan(typ string) (*clientChan, error) {
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
ChanType: typ,
PeersId: ch.id,
PeersWindow: 1 << 14,
MaxPacketSize: 1 << 15, // RFC 4253 6.1
})); err != nil {
c.chanlist.remove(ch.id)
return nil, err
}
// wait for response
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, errors.New(msg.Message)
default:
c.chanlist.remove(ch.id)
return nil, errors.New("Unexpected packet")
}
return ch, nil
}
// mainloop reads incoming messages and routes channel messages
// mainLoop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
// TODO(dfc) signal the underlying close to all channels
......@@ -271,7 +243,7 @@ func (c *ClientConn) mainLoop() {
case *windowAdjustMsg:
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
default:
fmt.Printf("mainLoop: unhandled %#v\n", msg)
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
}
}
}
......@@ -338,27 +310,16 @@ func newClientChan(t *transport, id uint32) *clientChan {
// Close closes the channel. This does not close the underlying connection.
func (c *clientChan) Close() error {
return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
PeersId: c.id,
PeersId: c.peersId,
}))
}
func (c *clientChan) sendChanReq(req channelRequestMsg) error {
if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
return err
}
msg := <-c.msg
if _, ok := msg.(*channelRequestSuccessMsg); ok {
return nil
}
return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
}
// Thread safe channel list.
type chanlist struct {
// protects concurrent access to chans
sync.Mutex
// chans are indexed by the local id of the channel, clientChan.id.
// The PeersId value of messages received by ClientConn.mainloop is
// The PeersId value of messages received by ClientConn.mainLoop is
// used to locate the right local clientChan in this slice.
chans []*clientChan
}
......@@ -395,7 +356,7 @@ func (c *chanlist) remove(id uint32) {
// A chanWriter represents the stdin of a remote process.
type chanWriter struct {
win chan int // receives window adjustments
id uint32 // this channel's id
peersId uint32 // the peer's id
rwin int // current rwin size
packetWriter // for sending channelDataMsg
}
......@@ -414,8 +375,8 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
n = len(data)
packet := make([]byte, 0, 9+n)
packet = append(packet, msgChannelData,
byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id),
byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n))
byte(w.peersId>>24), byte(w.peersId>>16), byte(w.peersId>>8), byte(w.peersId),
byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
err = w.writePacket(append(packet, data...))
w.rwin -= n
return
......@@ -424,7 +385,7 @@ func (w *chanWriter) Write(data []byte) (n int, err error) {
}
func (w *chanWriter) Close() error {
return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id}))
return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.peersId}))
}
// A chanReader represents stdout or stderr of a remote process.
......@@ -433,8 +394,8 @@ type chanReader struct {
// If writes to this channel block, they will block mainLoop, making
// it unable to receive new messages from the remote side.
data chan []byte // receives data from remote
id uint32
packetWriter // for sending windowAdjustMsg
peersId uint32 // the peer's id
packetWriter // for sending windowAdjustMsg
buf []byte
}
......@@ -446,7 +407,7 @@ func (r *chanReader) Read(data []byte) (int, error) {
n := copy(data, r.buf)
r.buf = r.buf[n:]
msg := windowAdjustMsg{
PeersId: r.id,
PeersId: r.peersId,
AdditionalBytes: uint32(n),
}
return n, r.writePacket(marshal(msgChannelWindowAdjust, msg))
......@@ -458,7 +419,3 @@ func (r *chanReader) Read(data []byte) (int, error) {
}
panic("unreachable")
}
func (r *chanReader) Close() error {
return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id}))
}
......@@ -70,7 +70,7 @@ func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err err
hashFunc := crypto.SHA1
h := hashFunc.New()
h.Write(data)
digest := h.Sum()
digest := h.Sum(nil)
return rsa.SignPKCS1v15(rand, k.keys[i], hashFunc, digest)
}
......
......@@ -224,3 +224,16 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
r = marshalString(r, pubKey)
return ret
}
// safeString sanitises s according to RFC 4251, section 9.2.
// All control characters except tab, carriage return and newline are
// replaced by 0x20.
func safeString(s string) string {
out := []byte(s)
for i, c := range out {
if c < 0x20 && c != 0xd && c != 0xa && c != 0x9 {
out[i] = 0x20
}
}
return string(out)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"testing"
)
var strings = map[string]string{
"\x20\x0d\x0a": "\x20\x0d\x0a",
"flibble": "flibble",
"new\x20line": "new\x20line",
"123456\x07789": "123456 789",
"\t\t\x10\r\n": "\t\t \r\n",
}
func TestSafeString(t *testing.T) {
for s, expected := range strings {
actual := safeString(s)
if expected != actual {
t.Errorf("expected: %v, actual: %v", []byte(expected), []byte(actual))
}
}
}
......@@ -92,9 +92,9 @@ Each ClientConn can support multiple interactive sessions, represented by a Sess
session, err := client.NewSession()
Once a Session is created, you can execute a single command on the remote side
using the Exec method.
using the Run method.
if err := session.Exec("/usr/bin/whoami"); err != nil {
if err := session.Run("/usr/bin/whoami"); err != nil {
panic("Failed to exec: " + err.String())
}
reader := bufio.NewReader(session.Stdin)
......
......@@ -207,11 +207,11 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
marshalInt(K, kInt)
h.Write(K)
H = h.Sum()
H = h.Sum(nil)
h.Reset()
h.Write(H)
hh := h.Sum()
hh := h.Sum(nil)
var sig []byte
switch hostKeyAlgo {
......@@ -478,7 +478,7 @@ userAuthLoop:
hashFunc := crypto.SHA1
h := hashFunc.New()
h.Write(signedData)
digest := h.Sum()
digest := h.Sum(nil)
rsaKey, ok := parseRSA(pubKey)
if !ok {
return ParseError{msgUserAuthRequest}
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session tests.
import (
"bytes"
"io"
"testing"
)
// dial constructs a new test server and returns a *ClientConn.
func dial(t *testing.T) *ClientConn {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
serverConfig.PubKeyCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
go func() {
defer l.Close()
conn, err := l.Accept()
if err != nil {
t.Errorf("Unable to accept: %v", err)
return
}
defer conn.Close()
if err := conn.Handshake(); err != nil {
t.Errorf("Unable to handshake: %v", err)
return
}
for {
ch, err := conn.Accept()
if err == io.EOF {
return
}
if err != nil {
t.Errorf("Unable to accept incoming channel request: %v", err)
return
}
if ch.ChannelType() != "session" {
ch.Reject(UnknownChannelType, "unknown channel type")
continue
}
ch.Accept()
go func() {
defer ch.Close()
// this string is returned to stdout
shell := NewServerShell(ch, "golang")
shell.ReadLine()
type exitMsg struct {
PeersId uint32
Request string
WantReply bool
Status uint32
}
// TODO(dfc) casting to the concrete type should not be
// necessary to send a packet.
msg := exitMsg{
PeersId: ch.(*channel).theirId,
Request: "exit-status",
WantReply: false,
Status: 0,
}
ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
}()
}
t.Log("done")
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(pw),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
return c
}
// Test a simple string is returned to session.Stdout.
func TestSessionShell(t *testing.T) {
conn := dial(t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %s", err)
}
defer session.Close()
stdout := new(bytes.Buffer)
session.Stdout = stdout
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %s", err)
}
actual := stdout.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
// Test a simple string is returned via StdoutPipe.
func TestSessionStdoutPipe(t *testing.T) {
conn := dial(t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %s", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("Unable to request StdoutPipe(): %v", err)
}
var buf bytes.Buffer
if err := session.Shell(); err != nil {
t.Fatalf("Unable to execute command: %s", err)
}
done := make(chan bool, 1)
go func() {
if _, err := io.Copy(&buf, stdout); err != nil {
t.Errorf("Copy of stdout failed: %v", err)
}
done <- true
}()
if err := session.Wait(); err != nil {
t.Fatalf("Remote command did not exit cleanly: %s", err)
}
<-done
actual := buf.String()
if actual != "golang" {
t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
}
}
......@@ -86,12 +86,12 @@ func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tc
clientChan: ch,
Reader: &chanReader{
packetWriter: ch,
id: ch.id,
peersId: ch.peersId,
data: ch.data,
},
Writer: &chanWriter{
packetWriter: ch,
id: ch.id,
peersId: ch.peersId,
win: ch.win,
},
}, nil
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// direct-tcpip functional tests
import (
"net"
"net/http"
"testing"
)
func TestTCPIPHTTP(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
// google.com will generate at least one redirect, possibly three
// depending on your location.
doTest(t, "http://google.com")
}
func TestTCPIPHTTPS(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
doTest(t, "https://encrypted.google.com/")
}
func doTest(t *testing.T, url string) {
config := &ClientConfig{
User: *sshuser,
Auth: []ClientAuth{
ClientAuthPassword(password(*sshpass)),
},
}
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
t.Fatalf("Unable to connect: %s", err)
}
defer conn.Close()
tr := &http.Transport{
Dial: func(n, addr string) (net.Conn, error) {
return conn.Dial(n, addr)
},
}
client := &http.Client{
Transport: tr,
}
resp, err := client.Get(url)
if err != nil {
t.Fatalf("unable to proxy: %s", err)
}
// got a body without error
t.Log(resp)
}
......@@ -123,7 +123,7 @@ func (r *reader) readOnePacket() ([]byte, error) {
if r.mac != nil {
r.mac.Write(packet[:length-1])
if subtle.ConstantTimeCompare(r.mac.Sum(), mac) != 1 {
if subtle.ConstantTimeCompare(r.mac.Sum(nil), mac) != 1 {
return nil, errors.New("ssh: MAC failure")
}
}
......@@ -201,7 +201,7 @@ func (w *writer) writePacket(packet []byte) error {
}
if w.mac != nil {
if _, err := w.Write(w.mac.Sum()); err != nil {
if _, err := w.Write(w.mac.Sum(nil)); err != nil {
return err
}
}
......@@ -297,7 +297,7 @@ func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
h.Write(digestsSoFar)
}
digest := h.Sum()
digest := h.Sum(nil)
n := copy(out, digest)
out = out[n:]
if len(out) > 0 {
......@@ -317,9 +317,9 @@ func (t truncatingMAC) Write(data []byte) (int, error) {
return t.hmac.Write(data)
}
func (t truncatingMAC) Sum() []byte {
digest := t.hmac.Sum()
return digest[:t.length]
func (t truncatingMAC) Sum(in []byte) []byte {
out := t.hmac.Sum(in)
return out[:len(in)+t.length]
}
func (t truncatingMAC) Reset() {
......
......@@ -202,7 +202,7 @@ func TestCheck(t *testing.T) {
// For easy debugging w/o changing the testing code,
// if there is a local test file, only test that file.
const testfile = "test.go"
if fi, err := os.Stat(testfile); err == nil && fi.IsRegular() {
if fi, err := os.Stat(testfile); err == nil && !fi.IsDir() {
fmt.Printf("WARNING: Testing only %s (remove it to run all tests)\n", testfile)
check(t, testfile, []string{testfile})
return
......
......@@ -59,7 +59,7 @@ func findPkg(path string) (filename, id string) {
// try extensions
for _, ext := range pkgExts {
filename = noext + ext
if f, err := os.Stat(filename); err == nil && f.IsRegular() {
if f, err := os.Stat(filename); err == nil && !f.IsDir() {
return
}
}
......
......@@ -58,32 +58,32 @@ func testPath(t *testing.T, path string) bool {
return true
}
const maxTime = 3e9 // maximum allotted testing time in ns
const maxTime = 3 * time.Second
func testDir(t *testing.T, dir string, endTime int64) (nimports int) {
func testDir(t *testing.T, dir string, endTime time.Time) (nimports int) {
dirname := filepath.Join(pkgRoot, dir)
list, err := ioutil.ReadDir(dirname)
if err != nil {
t.Errorf("testDir(%s): %s", dirname, err)
}
for _, f := range list {
if time.Nanoseconds() >= endTime {
if time.Now().After(endTime) {
t.Log("testing time used up")
return
}
switch {
case f.IsRegular():
case !f.IsDir():
// try extensions
for _, ext := range pkgExts {
if strings.HasSuffix(f.Name, ext) {
name := f.Name[0 : len(f.Name)-len(ext)] // remove extension
if strings.HasSuffix(f.Name(), ext) {
name := f.Name()[0 : len(f.Name())-len(ext)] // remove extension
if testPath(t, filepath.Join(dir, name)) {
nimports++
}
}
}
case f.IsDirectory():
nimports += testDir(t, filepath.Join(dir, f.Name), endTime)
case f.IsDir():
nimports += testDir(t, filepath.Join(dir, f.Name()), endTime)
}
}
return
......@@ -96,6 +96,6 @@ func TestGcImport(t *testing.T) {
if testPath(t, "./testdata/exports") {
nimports++
}
nimports += testDir(t, "", time.Nanoseconds()+maxTime) // installed packages
nimports += testDir(t, "", time.Now().Add(maxTime)) // installed packages
t.Logf("tested %d imports", nimports)
}
......@@ -47,8 +47,10 @@ func TestFmtInterface(t *testing.T) {
const b32 uint32 = 1<<32 - 1
const b64 uint64 = 1<<64 - 1
var array = []int{1, 2, 3, 4, 5}
var iarray = []interface{}{1, "hello", 2.5, nil}
var array = [5]int{1, 2, 3, 4, 5}
var iarray = [4]interface{}{1, "hello", 2.5, nil}
var slice = array[:]
var islice = iarray[:]
type A struct {
i int
......@@ -327,6 +329,12 @@ var fmttests = []struct {
{"%v", &array, "&[1 2 3 4 5]"},
{"%v", &iarray, "&[1 hello 2.5 <nil>]"},
// slices
{"%v", slice, "[1 2 3 4 5]"},
{"%v", islice, "[1 hello 2.5 <nil>]"},
{"%v", &slice, "&[1 2 3 4 5]"},
{"%v", &islice, "&[1 hello 2.5 <nil>]"},
// complexes with %v
{"%v", 1 + 2i, "(1+2i)"},
{"%v", complex64(1 + 2i), "(1+2i)"},
......@@ -359,6 +367,10 @@ var fmttests = []struct {
{"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`},
{"%#v", []int(nil), `[]int(nil)`},
{"%#v", []int{}, `[]int{}`},
{"%#v", array, `[5]int{1, 2, 3, 4, 5}`},
{"%#v", &array, `&[5]int{1, 2, 3, 4, 5}`},
{"%#v", iarray, `[4]interface {}{1, "hello", 2.5, interface {}(nil)}`},
{"%#v", &iarray, `&[4]interface {}{1, "hello", 2.5, interface {}(nil)}`},
{"%#v", map[int]byte(nil), `map[int] uint8(nil)`},
{"%#v", map[int]byte{}, `map[int] uint8{}`},
......
......@@ -877,7 +877,7 @@ BigSwitch:
}
if goSyntax {
p.buf.WriteString(value.Type().String())
if f.IsNil() {
if f.Kind() == reflect.Slice && f.IsNil() {
p.buf.WriteString("(nil)")
break
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment