Commit ab61e9c4 by Ian Lance Taylor

libgo: Update to weekly.2011-11-18.

From-SVN: r182266
parent 6e456f4c
2f4482b89a6b b4a91b693374
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.
...@@ -648,7 +648,8 @@ go_math_files = \ ...@@ -648,7 +648,8 @@ go_math_files = \
go_mime_files = \ go_mime_files = \
go/mime/grammar.go \ go/mime/grammar.go \
go/mime/mediatype.go \ go/mime/mediatype.go \
go/mime/type.go go/mime/type.go \
go/mime/type_unix.go
if LIBGO_IS_RTEMS if LIBGO_IS_RTEMS
go_net_fd_os_file = go/net/fd_select.go go_net_fd_os_file = go/net/fd_select.go
...@@ -770,7 +771,6 @@ go_os_files = \ ...@@ -770,7 +771,6 @@ go_os_files = \
$(go_os_dir_file) \ $(go_os_dir_file) \
go/os/dir.go \ go/os/dir.go \
go/os/env.go \ go/os/env.go \
go/os/env_unix.go \
go/os/error.go \ go/os/error.go \
go/os/error_posix.go \ go/os/error_posix.go \
go/os/exec.go \ go/os/exec.go \
...@@ -1156,6 +1156,7 @@ go_exp_sql_files = \ ...@@ -1156,6 +1156,7 @@ go_exp_sql_files = \
go/exp/sql/sql.go go/exp/sql/sql.go
go_exp_ssh_files = \ go_exp_ssh_files = \
go/exp/ssh/channel.go \ go/exp/ssh/channel.go \
go/exp/ssh/cipher.go \
go/exp/ssh/client.go \ go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \ go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \ go/exp/ssh/common.go \
...@@ -1164,10 +1165,11 @@ go_exp_ssh_files = \ ...@@ -1164,10 +1165,11 @@ go_exp_ssh_files = \
go/exp/ssh/server.go \ go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \ go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \ go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go go/exp/ssh/transport.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/shell.go \ go/exp/terminal/terminal.go \
go/exp/terminal/terminal.go go/exp/terminal/util.go
go_exp_types_files = \ go_exp_types_files = \
go/exp/types/check.go \ go/exp/types/check.go \
go/exp/types/const.go \ go/exp/types/const.go \
...@@ -1546,6 +1548,7 @@ syscall_netlink_file = ...@@ -1546,6 +1548,7 @@ syscall_netlink_file =
endif endif
go_base_syscall_files = \ go_base_syscall_files = \
go/syscall/env_unix.go \
go/syscall/libcall_support.go \ go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \ go/syscall/libcall_posix.go \
go/syscall/socket.go \ go/syscall/socket.go \
......
...@@ -1032,7 +1032,8 @@ go_math_files = \ ...@@ -1032,7 +1032,8 @@ go_math_files = \
go_mime_files = \ go_mime_files = \
go/mime/grammar.go \ go/mime/grammar.go \
go/mime/mediatype.go \ go/mime/mediatype.go \
go/mime/type.go go/mime/type.go \
go/mime/type_unix.go
# By default use select with pipes. Most systems should have # By default use select with pipes. Most systems should have
# something better. # something better.
...@@ -1103,7 +1104,6 @@ go_os_files = \ ...@@ -1103,7 +1104,6 @@ go_os_files = \
$(go_os_dir_file) \ $(go_os_dir_file) \
go/os/dir.go \ go/os/dir.go \
go/os/env.go \ go/os/env.go \
go/os/env_unix.go \
go/os/error.go \ go/os/error.go \
go/os/error_posix.go \ go/os/error_posix.go \
go/os/exec.go \ go/os/exec.go \
...@@ -1521,6 +1521,7 @@ go_exp_sql_files = \ ...@@ -1521,6 +1521,7 @@ go_exp_sql_files = \
go_exp_ssh_files = \ go_exp_ssh_files = \
go/exp/ssh/channel.go \ go/exp/ssh/channel.go \
go/exp/ssh/cipher.go \
go/exp/ssh/client.go \ go/exp/ssh/client.go \
go/exp/ssh/client_auth.go \ go/exp/ssh/client_auth.go \
go/exp/ssh/common.go \ go/exp/ssh/common.go \
...@@ -1529,11 +1530,12 @@ go_exp_ssh_files = \ ...@@ -1529,11 +1530,12 @@ go_exp_ssh_files = \
go/exp/ssh/server.go \ go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \ go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \ go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go go/exp/ssh/transport.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/shell.go \ go/exp/terminal/terminal.go \
go/exp/terminal/terminal.go go/exp/terminal/util.go
go_exp_types_files = \ go_exp_types_files = \
go/exp/types/check.go \ go/exp/types/check.go \
...@@ -1890,6 +1892,7 @@ go_unicode_utf8_files = \ ...@@ -1890,6 +1892,7 @@ go_unicode_utf8_files = \
# Support for netlink sockets and messages. # Support for netlink sockets and messages.
@LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go @LIBGO_IS_LINUX_TRUE@syscall_netlink_file = go/syscall/netlink_linux.go
go_base_syscall_files = \ go_base_syscall_files = \
go/syscall/env_unix.go \
go/syscall/libcall_support.go \ go/syscall/libcall_support.go \
go/syscall/libcall_posix.go \ go/syscall/libcall_posix.go \
go/syscall/socket.go \ go/syscall/socket.go \
......
...@@ -10,7 +10,6 @@ import ( ...@@ -10,7 +10,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os"
"strings" "strings"
"testing" "testing"
"testing/iotest" "testing/iotest"
...@@ -425,9 +424,9 @@ var errorWriterTests = []errorWriterTest{ ...@@ -425,9 +424,9 @@ var errorWriterTests = []errorWriterTest{
{0, 1, nil, io.ErrShortWrite}, {0, 1, nil, io.ErrShortWrite},
{1, 2, nil, io.ErrShortWrite}, {1, 2, nil, io.ErrShortWrite},
{1, 1, nil, nil}, {1, 1, nil, nil},
{0, 1, os.EPIPE, os.EPIPE}, {0, 1, io.ErrClosedPipe, io.ErrClosedPipe},
{1, 2, os.EPIPE, os.EPIPE}, {1, 2, io.ErrClosedPipe, io.ErrClosedPipe},
{1, 1, os.EPIPE, os.EPIPE}, {1, 1, io.ErrClosedPipe, io.ErrClosedPipe},
} }
func TestWriteErrors(t *testing.T) { func TestWriteErrors(t *testing.T) {
......
...@@ -91,6 +91,11 @@ type rune rune ...@@ -91,6 +91,11 @@ type rune rune
// invocation. // invocation.
type Type int type Type int
// Type1 is here for the purposes of documentation only. It is a stand-in
// for any Go type, but represents the same type for any given function
// invocation.
type Type1 int
// IntegerType is here for the purposes of documentation only. It is a stand-in // IntegerType is here for the purposes of documentation only. It is a stand-in
// for any integer type: int, uint, int8 etc. // for any integer type: int, uint, int8 etc.
type IntegerType int type IntegerType int
...@@ -119,6 +124,11 @@ func append(slice []Type, elems ...Type) []Type ...@@ -119,6 +124,11 @@ func append(slice []Type, elems ...Type) []Type
// len(src) and len(dst). // len(src) and len(dst).
func copy(dst, src []Type) int func copy(dst, src []Type) int
// The delete built-in function deletes the element with the specified key
// (m[key]) from the map. If there is no such element, delete is a no-op.
// If m is nil, delete panics.
func delete(m map[Type]Type1, key Type)
// The len built-in function returns the length of v, according to its type: // The len built-in function returns the length of v, according to its type:
// Array: the number of elements in v. // Array: the number of elements in v.
// Pointer to array: the number of elements in *v (even if v is nil). // Pointer to array: the number of elements in *v (even if v is nil).
...@@ -171,7 +181,7 @@ func complex(r, i FloatType) ComplexType ...@@ -171,7 +181,7 @@ func complex(r, i FloatType) ComplexType
// The return value will be floating point type corresponding to the type of c. // The return value will be floating point type corresponding to the type of c.
func real(c ComplexType) FloatType func real(c ComplexType) FloatType
// The imaginary built-in function returns the imaginary part of the complex // The imag built-in function returns the imaginary part of the complex
// number c. The return value will be floating point type corresponding to // number c. The return value will be floating point type corresponding to
// the type of c. // the type of c.
func imag(c ComplexType) FloatType func imag(c ComplexType) FloatType
......
...@@ -662,48 +662,49 @@ func TestRunes(t *testing.T) { ...@@ -662,48 +662,49 @@ func TestRunes(t *testing.T) {
} }
type TrimTest struct { type TrimTest struct {
f func([]byte, string) []byte f string
in, cutset, out string in, cutset, out string
} }
var trimTests = []TrimTest{ var trimTests = []TrimTest{
{Trim, "abba", "a", "bb"}, {"Trim", "abba", "a", "bb"},
{Trim, "abba", "ab", ""}, {"Trim", "abba", "ab", ""},
{TrimLeft, "abba", "ab", ""}, {"TrimLeft", "abba", "ab", ""},
{TrimRight, "abba", "ab", ""}, {"TrimRight", "abba", "ab", ""},
{TrimLeft, "abba", "a", "bba"}, {"TrimLeft", "abba", "a", "bba"},
{TrimRight, "abba", "a", "abb"}, {"TrimRight", "abba", "a", "abb"},
{Trim, "<tag>", "<>", "tag"}, {"Trim", "<tag>", "<>", "tag"},
{Trim, "* listitem", " *", "listitem"}, {"Trim", "* listitem", " *", "listitem"},
{Trim, `"quote"`, `"`, "quote"}, {"Trim", `"quote"`, `"`, "quote"},
{Trim, "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"}, {"Trim", "\u2C6F\u2C6F\u0250\u0250\u2C6F\u2C6F", "\u2C6F", "\u0250\u0250"},
//empty string tests //empty string tests
{Trim, "abba", "", "abba"}, {"Trim", "abba", "", "abba"},
{Trim, "", "123", ""}, {"Trim", "", "123", ""},
{Trim, "", "", ""}, {"Trim", "", "", ""},
{TrimLeft, "abba", "", "abba"}, {"TrimLeft", "abba", "", "abba"},
{TrimLeft, "", "123", ""}, {"TrimLeft", "", "123", ""},
{TrimLeft, "", "", ""}, {"TrimLeft", "", "", ""},
{TrimRight, "abba", "", "abba"}, {"TrimRight", "abba", "", "abba"},
{TrimRight, "", "123", ""}, {"TrimRight", "", "123", ""},
{TrimRight, "", "", ""}, {"TrimRight", "", "", ""},
{TrimRight, "☺\xc0", "☺", "☺\xc0"}, {"TrimRight", "☺\xc0", "☺", "☺\xc0"},
} }
func TestTrim(t *testing.T) { func TestTrim(t *testing.T) {
for _, tc := range trimTests { for _, tc := range trimTests {
actual := string(tc.f([]byte(tc.in), tc.cutset)) name := tc.f
var name string var f func([]byte, string) []byte
switch tc.f { switch name {
case Trim: case "Trim":
name = "Trim" f = Trim
case TrimLeft: case "TrimLeft":
name = "TrimLeft" f = TrimLeft
case TrimRight: case "TrimRight":
name = "TrimRight" f = TrimRight
default: default:
t.Error("Undefined trim function") t.Error("Undefined trim function %s", name)
} }
actual := string(f([]byte(tc.in), tc.cutset))
if actual != tc.out { if actual != tc.out {
t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out) t.Errorf("%s(%q, %q) = %q; want %q", name, tc.in, tc.cutset, actual, tc.out)
} }
......
...@@ -19,7 +19,6 @@ import ( ...@@ -19,7 +19,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
) )
// Order specifies the bit ordering in an LZW data stream. // Order specifies the bit ordering in an LZW data stream.
...@@ -212,8 +211,10 @@ func (d *decoder) flush() { ...@@ -212,8 +211,10 @@ func (d *decoder) flush() {
d.o = 0 d.o = 0
} }
var errClosed = errors.New("compress/lzw: reader/writer is closed")
func (d *decoder) Close() error { func (d *decoder) Close() error {
d.err = os.EINVAL // in case any Reads come along d.err = errClosed // in case any Reads come along
return nil return nil
} }
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
) )
// A writer is a buffered, flushable writer. // A writer is a buffered, flushable writer.
...@@ -49,8 +48,9 @@ const ( ...@@ -49,8 +48,9 @@ const (
type encoder struct { type encoder struct {
// w is the writer that compressed bytes are written to. // w is the writer that compressed bytes are written to.
w writer w writer
// write, bits, nBits and width are the state for converting a code stream // order, write, bits, nBits and width are the state for
// into a byte stream. // converting a code stream into a byte stream.
order Order
write func(*encoder, uint32) error write func(*encoder, uint32) error
bits uint32 bits uint32
nBits uint nBits uint
...@@ -64,7 +64,7 @@ type encoder struct { ...@@ -64,7 +64,7 @@ type encoder struct {
// call. It is equal to invalidCode if there was no such call. // call. It is equal to invalidCode if there was no such call.
savedCode uint32 savedCode uint32
// err is the first error encountered during writing. Closing the encoder // err is the first error encountered during writing. Closing the encoder
// will make any future Write calls return os.EINVAL. // will make any future Write calls return errClosed
err error err error
// table is the hash table from 20-bit keys to 12-bit values. Each table // table is the hash table from 20-bit keys to 12-bit values. Each table
// entry contains key<<12|val and collisions resolve by linear probing. // entry contains key<<12|val and collisions resolve by linear probing.
...@@ -191,13 +191,13 @@ loop: ...@@ -191,13 +191,13 @@ loop:
// flush e's underlying writer. // flush e's underlying writer.
func (e *encoder) Close() error { func (e *encoder) Close() error {
if e.err != nil { if e.err != nil {
if e.err == os.EINVAL { if e.err == errClosed {
return nil return nil
} }
return e.err return e.err
} }
// Make any future calls to Write return os.EINVAL. // Make any future calls to Write return errClosed.
e.err = os.EINVAL e.err = errClosed
// Write the savedCode if valid. // Write the savedCode if valid.
if e.savedCode != invalidCode { if e.savedCode != invalidCode {
if err := e.write(e, e.savedCode); err != nil { if err := e.write(e, e.savedCode); err != nil {
...@@ -214,7 +214,7 @@ func (e *encoder) Close() error { ...@@ -214,7 +214,7 @@ func (e *encoder) Close() error {
} }
// Write the final bits. // Write the final bits.
if e.nBits > 0 { if e.nBits > 0 {
if e.write == (*encoder).writeMSB { if e.order == MSB {
e.bits >>= 24 e.bits >>= 24
} }
if err := e.w.WriteByte(uint8(e.bits)); err != nil { if err := e.w.WriteByte(uint8(e.bits)); err != nil {
...@@ -250,6 +250,7 @@ func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser { ...@@ -250,6 +250,7 @@ func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
lw := uint(litWidth) lw := uint(litWidth)
return &encoder{ return &encoder{
w: bw, w: bw,
order: order,
write: write, write: write,
width: 1 + lw, width: 1 + lw,
litWidth: lw, litWidth: lw,
......
...@@ -50,10 +50,6 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) { ...@@ -50,10 +50,6 @@ func testFile(t *testing.T, fn string, order Order, litWidth int) {
return return
} }
_, err1 := lzww.Write(b[:n]) _, err1 := lzww.Write(b[:n])
if err1 == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reportable) error broke the pipe.
return
}
if err1 != nil { if err1 != nil {
t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1) t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
return return
......
...@@ -59,10 +59,6 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) { ...@@ -59,10 +59,6 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
} }
defer zlibw.Close() defer zlibw.Close()
_, err = zlibw.Write(b0) _, err = zlibw.Write(b0)
if err == os.EPIPE {
// Fail, but do not report the error, as some other (presumably reported) error broke the pipe.
return
}
if err != nil { if err != nil {
t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err) t.Errorf("%s (level=%d, dict=%q): %v", fn, level, d, err)
return return
......
...@@ -41,7 +41,7 @@ func NewCipher(key []byte) (*Cipher, error) { ...@@ -41,7 +41,7 @@ func NewCipher(key []byte) (*Cipher, error) {
} }
// BlockSize returns the AES block size, 16 bytes. // BlockSize returns the AES block size, 16 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Block interface in the
// package "crypto/cipher". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }
......
...@@ -54,7 +54,7 @@ func NewSaltedCipher(key, salt []byte) (*Cipher, error) { ...@@ -54,7 +54,7 @@ func NewSaltedCipher(key, salt []byte) (*Cipher, error) {
} }
// BlockSize returns the Blowfish block size, 8 bytes. // BlockSize returns the Blowfish block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Block interface in the
// package "crypto/cipher". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }
......
...@@ -28,16 +28,16 @@ func (r *rngReader) Read(b []byte) (n int, err error) { ...@@ -28,16 +28,16 @@ func (r *rngReader) Read(b []byte) (n int, err error) {
if r.prov == 0 { if r.prov == 0 {
const provType = syscall.PROV_RSA_FULL const provType = syscall.PROV_RSA_FULL
const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT const flags = syscall.CRYPT_VERIFYCONTEXT | syscall.CRYPT_SILENT
errno := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags) err := syscall.CryptAcquireContext(&r.prov, nil, nil, provType, flags)
if errno != 0 { if err != nil {
r.mu.Unlock() r.mu.Unlock()
return 0, os.NewSyscallError("CryptAcquireContext", errno) return 0, os.NewSyscallError("CryptAcquireContext", err)
} }
} }
r.mu.Unlock() r.mu.Unlock()
errno := syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0]) err = syscall.CryptGenRandom(r.prov, uint32(len(b)), &b[0])
if errno != 0 { if err != nil {
return 0, os.NewSyscallError("CryptGenRandom", errno) return 0, os.NewSyscallError("CryptGenRandom", err)
} }
return len(b), nil return len(b), nil
} }
...@@ -5,16 +5,16 @@ ...@@ -5,16 +5,16 @@
package rand package rand
import ( import (
"errors"
"io" "io"
"math/big" "math/big"
"os"
) )
// Prime returns a number, p, of the given size, such that p is prime // Prime returns a number, p, of the given size, such that p is prime
// with high probability. // with high probability.
func Prime(rand io.Reader, bits int) (p *big.Int, err error) { func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
if bits < 1 { if bits < 1 {
err = os.EINVAL err = errors.New("crypto/rand: prime size must be positive")
} }
b := uint(bits % 8) b := uint(bits % 8)
......
...@@ -93,7 +93,8 @@ func (c *Conn) SetTimeout(nsec int64) error { ...@@ -93,7 +93,8 @@ func (c *Conn) SetTimeout(nsec int64) error {
} }
// SetReadTimeout sets the time (in nanoseconds) that // SetReadTimeout sets the time (in nanoseconds) that
// Read will wait for data before returning os.EAGAIN. // Read will wait for data before returning a net.Error
// with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline. // Setting nsec == 0 (the default) disables the deadline.
func (c *Conn) SetReadTimeout(nsec int64) error { func (c *Conn) SetReadTimeout(nsec int64) error {
return c.conn.SetReadTimeout(nsec) return c.conn.SetReadTimeout(nsec)
...@@ -737,7 +738,7 @@ func (c *Conn) Write(b []byte) (n int, err error) { ...@@ -737,7 +738,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return c.writeRecord(recordTypeApplicationData, b) return c.writeRecord(recordTypeApplicationData, b)
} }
// Read can be made to time out and return err == os.EAGAIN // Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetTimeout and SetReadTimeout. // after a fixed time limit; see SetTimeout and SetReadTimeout.
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
if err = c.Handshake(); err != nil { if err = c.Handshake(); err != nil {
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
package tls package tls
import "bytes"
type clientHelloMsg struct { type clientHelloMsg struct {
raw []byte raw []byte
vers uint16 vers uint16
...@@ -18,6 +20,25 @@ type clientHelloMsg struct { ...@@ -18,6 +20,25 @@ type clientHelloMsg struct {
supportedPoints []uint8 supportedPoints []uint8
} }
func (m *clientHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*clientHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
eqUint16s(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints)
}
func (m *clientHelloMsg) marshal() []byte { func (m *clientHelloMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -309,6 +330,23 @@ type serverHelloMsg struct { ...@@ -309,6 +330,23 @@ type serverHelloMsg struct {
ocspStapling bool ocspStapling bool
} }
func (m *serverHelloMsg) equal(i interface{}) bool {
m1, ok := i.(*serverHelloMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
m.cipherSuite == m1.cipherSuite &&
m.compressionMethod == m1.compressionMethod &&
m.nextProtoNeg == m1.nextProtoNeg &&
eqStrings(m.nextProtos, m1.nextProtos) &&
m.ocspStapling == m1.ocspStapling
}
func (m *serverHelloMsg) marshal() []byte { func (m *serverHelloMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -463,6 +501,16 @@ type certificateMsg struct { ...@@ -463,6 +501,16 @@ type certificateMsg struct {
certificates [][]byte certificates [][]byte
} }
func (m *certificateMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
eqByteSlices(m.certificates, m1.certificates)
}
func (m *certificateMsg) marshal() (x []byte) { func (m *certificateMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct { ...@@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
key []byte key []byte
} }
func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*serverKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.key, m1.key)
}
func (m *serverKeyExchangeMsg) marshal() []byte { func (m *serverKeyExchangeMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -571,6 +629,17 @@ type certificateStatusMsg struct { ...@@ -571,6 +629,17 @@ type certificateStatusMsg struct {
response []byte response []byte
} }
func (m *certificateStatusMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateStatusMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.statusType == m1.statusType &&
bytes.Equal(m.response, m1.response)
}
func (m *certificateStatusMsg) marshal() []byte { func (m *certificateStatusMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { ...@@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
type serverHelloDoneMsg struct{} type serverHelloDoneMsg struct{}
func (m *serverHelloDoneMsg) equal(i interface{}) bool {
_, ok := i.(*serverHelloDoneMsg)
return ok
}
func (m *serverHelloDoneMsg) marshal() []byte { func (m *serverHelloDoneMsg) marshal() []byte {
x := make([]byte, 4) x := make([]byte, 4)
x[0] = typeServerHelloDone x[0] = typeServerHelloDone
...@@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct { ...@@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
ciphertext []byte ciphertext []byte
} }
func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
m1, ok := i.(*clientKeyExchangeMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.ciphertext, m1.ciphertext)
}
func (m *clientKeyExchangeMsg) marshal() []byte { func (m *clientKeyExchangeMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -671,6 +755,16 @@ type finishedMsg struct { ...@@ -671,6 +755,16 @@ type finishedMsg struct {
verifyData []byte verifyData []byte
} }
func (m *finishedMsg) equal(i interface{}) bool {
m1, ok := i.(*finishedMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.verifyData, m1.verifyData)
}
func (m *finishedMsg) marshal() (x []byte) { func (m *finishedMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -698,6 +792,16 @@ type nextProtoMsg struct { ...@@ -698,6 +792,16 @@ type nextProtoMsg struct {
proto string proto string
} }
func (m *nextProtoMsg) equal(i interface{}) bool {
m1, ok := i.(*nextProtoMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.proto == m1.proto
}
func (m *nextProtoMsg) marshal() []byte { func (m *nextProtoMsg) marshal() []byte {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -759,6 +863,17 @@ type certificateRequestMsg struct { ...@@ -759,6 +863,17 @@ type certificateRequestMsg struct {
certificateAuthorities [][]byte certificateAuthorities [][]byte
} }
func (m *certificateRequestMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateRequestMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
}
func (m *certificateRequestMsg) marshal() (x []byte) { func (m *certificateRequestMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -859,6 +974,16 @@ type certificateVerifyMsg struct { ...@@ -859,6 +974,16 @@ type certificateVerifyMsg struct {
signature []byte signature []byte
} }
func (m *certificateVerifyMsg) equal(i interface{}) bool {
m1, ok := i.(*certificateVerifyMsg)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
bytes.Equal(m.signature, m1.signature)
}
func (m *certificateVerifyMsg) marshal() (x []byte) { func (m *certificateVerifyMsg) marshal() (x []byte) {
if m.raw != nil { if m.raw != nil {
return m.raw return m.raw
...@@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool { ...@@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
return true return true
} }
func eqUint16s(x, y []uint16) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqStrings(x, y []string) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if y[i] != v {
return false
}
}
return true
}
func eqByteSlices(x, y [][]byte) bool {
if len(x) != len(y) {
return false
}
for i, v := range x {
if !bytes.Equal(v, y[i]) {
return false
}
}
return true
}
...@@ -27,10 +27,12 @@ var tests = []interface{}{ ...@@ -27,10 +27,12 @@ var tests = []interface{}{
type testMessage interface { type testMessage interface {
marshal() []byte marshal() []byte
unmarshal([]byte) bool unmarshal([]byte) bool
equal(interface{}) bool
} }
func TestMarshalUnmarshal(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0)) rand := rand.New(rand.NewSource(0))
for i, iface := range tests { for i, iface := range tests {
ty := reflect.ValueOf(iface).Type() ty := reflect.ValueOf(iface).Type()
...@@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) { ...@@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) {
} }
m2.marshal() // to fill any marshal cache in the message m2.marshal() // to fill any marshal cache in the message
if !reflect.DeepEqual(m1, m2) { if !m1.equal(m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break break
} }
......
...@@ -12,8 +12,8 @@ import ( ...@@ -12,8 +12,8 @@ import (
) )
func loadStore(roots *x509.CertPool, name string) { func loadStore(roots *x509.CertPool, name string) {
store, errno := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name)) store, err := syscall.CertOpenSystemStore(syscall.InvalidHandle, syscall.StringToUTF16Ptr(name))
if errno != 0 { if err != nil {
return return
} }
......
...@@ -44,7 +44,7 @@ func NewCipher(key []byte) (*Cipher, error) { ...@@ -44,7 +44,7 @@ func NewCipher(key []byte) (*Cipher, error) {
} }
// BlockSize returns the XTEA block size, 8 bytes. // BlockSize returns the XTEA block size, 8 bytes.
// It is necessary to satisfy the Cipher interface in the // It is necessary to satisfy the Block interface in the
// package "crypto/cipher". // package "crypto/cipher".
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Large data benchmark.
// The JSON data is a summary of agl's changes in the
// go, webkit, and chromium open source projects.
// We benchmark converting between the JSON form
// and in-memory data structures.
package json
import (
"bytes"
"compress/gzip"
"io/ioutil"
"os"
"testing"
)
type codeResponse struct {
Tree *codeNode `json:"tree"`
Username string `json:"username"`
}
type codeNode struct {
Name string `json:"name"`
Kids []*codeNode `json:"kids"`
CLWeight float64 `json:"cl_weight"`
Touches int `json:"touches"`
MinT int64 `json:"min_t"`
MaxT int64 `json:"max_t"`
MeanT int64 `json:"mean_t"`
}
var codeJSON []byte
var codeStruct codeResponse
func codeInit() {
f, err := os.Open("testdata/code.json.gz")
if err != nil {
panic(err)
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
panic(err)
}
data, err := ioutil.ReadAll(gz)
if err != nil {
panic(err)
}
codeJSON = data
if err := Unmarshal(codeJSON, &codeStruct); err != nil {
panic("unmarshal code.json: " + err.Error())
}
if data, err = Marshal(&codeStruct); err != nil {
panic("marshal code.json: " + err.Error())
}
if !bytes.Equal(data, codeJSON) {
println("different lengths", len(data), len(codeJSON))
for i := 0; i < len(data) && i < len(codeJSON); i++ {
if data[i] != codeJSON[i] {
println("re-marshal: changed at byte", i)
println("orig: ", string(codeJSON[i-10:i+10]))
println("new: ", string(data[i-10:i+10]))
break
}
}
panic("re-marshal code.json: different result")
}
}
func BenchmarkCodeEncoder(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
enc := NewEncoder(ioutil.Discard)
for i := 0; i < b.N; i++ {
if err := enc.Encode(&codeStruct); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeMarshal(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
for i := 0; i < b.N; i++ {
if _, err := Marshal(&codeStruct); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeDecoder(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
var buf bytes.Buffer
dec := NewDecoder(&buf)
var r codeResponse
for i := 0; i < b.N; i++ {
buf.Write(codeJSON)
// hide EOF
buf.WriteByte('\n')
buf.WriteByte('\n')
buf.WriteByte('\n')
if err := dec.Decode(&r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeUnmarshal(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
for i := 0; i < b.N; i++ {
var r codeResponse
if err := Unmarshal(codeJSON, &r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
func BenchmarkCodeUnmarshalReuse(b *testing.B) {
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
var r codeResponse
for i := 0; i < b.N; i++ {
if err := Unmarshal(codeJSON, &r); err != nil {
panic(err)
}
}
b.SetBytes(int64(len(codeJSON)))
}
...@@ -227,7 +227,7 @@ func (d *decodeState) value(v reflect.Value) { ...@@ -227,7 +227,7 @@ func (d *decodeState) value(v reflect.Value) {
// d.scan thinks we're still at the beginning of the item. // d.scan thinks we're still at the beginning of the item.
// Feed in an empty string - the shortest, simplest value - // Feed in an empty string - the shortest, simplest value -
// so that it knows we got to the end of the value. // so that it knows we got to the end of the value.
if d.scan.step == stateRedo { if d.scan.redo {
panic("redo") panic("redo")
} }
d.scan.step(&d.scan, '"') d.scan.step(&d.scan, '"')
...@@ -381,6 +381,7 @@ func (d *decodeState) array(v reflect.Value) { ...@@ -381,6 +381,7 @@ func (d *decodeState) array(v reflect.Value) {
d.error(errPhase) d.error(errPhase)
} }
} }
if i < av.Len() { if i < av.Len() {
if !sv.IsValid() { if !sv.IsValid() {
// Array. Zero the rest. // Array. Zero the rest.
...@@ -392,6 +393,9 @@ func (d *decodeState) array(v reflect.Value) { ...@@ -392,6 +393,9 @@ func (d *decodeState) array(v reflect.Value) {
sv.SetLen(i) sv.SetLen(i)
} }
} }
if i == 0 && av.Kind() == reflect.Slice && sv.IsNil() {
sv.Set(reflect.MakeSlice(sv.Type(), 0, 0))
}
} }
// object consumes an object from d.data[d.off-1:], decoding into the value v. // object consumes an object from d.data[d.off-1:], decoding into the value v.
......
...@@ -80,6 +80,9 @@ type scanner struct { ...@@ -80,6 +80,9 @@ type scanner struct {
// on a 64-bit Mac Mini, and it's nicer to read. // on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, int) int step func(*scanner, int) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values. // Stack of what we're in the middle of - array values, object keys, object values.
parseState []int parseState []int
...@@ -87,6 +90,7 @@ type scanner struct { ...@@ -87,6 +90,7 @@ type scanner struct {
err error err error
// 1-byte redo (see undo method) // 1-byte redo (see undo method)
redo bool
redoCode int redoCode int
redoState func(*scanner, int) int redoState func(*scanner, int) int
...@@ -135,6 +139,8 @@ func (s *scanner) reset() { ...@@ -135,6 +139,8 @@ func (s *scanner) reset() {
s.step = stateBeginValue s.step = stateBeginValue
s.parseState = s.parseState[0:0] s.parseState = s.parseState[0:0]
s.err = nil s.err = nil
s.redo = false
s.endTop = false
} }
// eof tells the scanner that the end of input has been reached. // eof tells the scanner that the end of input has been reached.
...@@ -143,11 +149,11 @@ func (s *scanner) eof() int { ...@@ -143,11 +149,11 @@ func (s *scanner) eof() int {
if s.err != nil { if s.err != nil {
return scanError return scanError
} }
if s.step == stateEndTop { if s.endTop {
return scanEnd return scanEnd
} }
s.step(s, ' ') s.step(s, ' ')
if s.step == stateEndTop { if s.endTop {
return scanEnd return scanEnd
} }
if s.err == nil { if s.err == nil {
...@@ -166,8 +172,10 @@ func (s *scanner) pushParseState(p int) { ...@@ -166,8 +172,10 @@ func (s *scanner) pushParseState(p int) {
func (s *scanner) popParseState() { func (s *scanner) popParseState() {
n := len(s.parseState) - 1 n := len(s.parseState) - 1
s.parseState = s.parseState[0:n] s.parseState = s.parseState[0:n]
s.redo = false
if n == 0 { if n == 0 {
s.step = stateEndTop s.step = stateEndTop
s.endTop = true
} else { } else {
s.step = stateEndValue s.step = stateEndValue
} }
...@@ -269,6 +277,7 @@ func stateEndValue(s *scanner, c int) int { ...@@ -269,6 +277,7 @@ func stateEndValue(s *scanner, c int) int {
if n == 0 { if n == 0 {
// Completed top-level before the current byte. // Completed top-level before the current byte.
s.step = stateEndTop s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c) return stateEndTop(s, c)
} }
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') { if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
...@@ -606,16 +615,18 @@ func quoteChar(c int) string { ...@@ -606,16 +615,18 @@ func quoteChar(c int) string {
// undo causes the scanner to return scanCode from the next state transition. // undo causes the scanner to return scanCode from the next state transition.
// This gives callers a simple 1-byte undo mechanism. // This gives callers a simple 1-byte undo mechanism.
func (s *scanner) undo(scanCode int) { func (s *scanner) undo(scanCode int) {
if s.step == stateRedo { if s.redo {
panic("invalid use of scanner") panic("json: invalid use of scanner")
} }
s.redoCode = scanCode s.redoCode = scanCode
s.redoState = s.step s.redoState = s.step
s.step = stateRedo s.step = stateRedo
s.redo = true
} }
// stateRedo helps implement the scanner's 1-byte undo. // stateRedo helps implement the scanner's 1-byte undo.
func stateRedo(s *scanner, c int) int { func stateRedo(s *scanner, c int) int {
s.redo = false
s.step = s.redoState s.step = s.redoState
return s.redoCode return s.redoCode
} }
...@@ -186,11 +186,12 @@ func TestNextValueBig(t *testing.T) { ...@@ -186,11 +186,12 @@ func TestNextValueBig(t *testing.T) {
} }
} }
var benchScan scanner
func BenchmarkSkipValue(b *testing.B) { func BenchmarkSkipValue(b *testing.B) {
initBig() initBig()
var scan scanner
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
nextValue(jsonBig, &scan) nextValue(jsonBig, &benchScan)
} }
b.SetBytes(int64(len(jsonBig))) b.SetBytes(int64(len(jsonBig)))
} }
......
...@@ -7,7 +7,6 @@ package xml ...@@ -7,7 +7,6 @@ package xml
import ( import (
"bytes" "bytes"
"io" "io"
"os"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
...@@ -43,17 +42,17 @@ var rawTokens = []Token{ ...@@ -43,17 +42,17 @@ var rawTokens = []Token{
CharData([]byte("World <>'\" 白鵬翔")), CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"", "hello"}}, EndElement{Name{"", "hello"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"", "goodbye"}, nil}, StartElement{Name{"", "goodbye"}, []Attr{}},
EndElement{Name{"", "goodbye"}}, EndElement{Name{"", "goodbye"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"", "inner"}, nil}, StartElement{Name{"", "inner"}, []Attr{}},
EndElement{Name{"", "inner"}}, EndElement{Name{"", "inner"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
EndElement{Name{"", "outer"}}, EndElement{Name{"", "outer"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"tag", "name"}, nil}, StartElement{Name{"tag", "name"}, []Attr{}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
CharData([]byte("Some text here.")), CharData([]byte("Some text here.")),
CharData([]byte("\n ")), CharData([]byte("\n ")),
...@@ -77,17 +76,17 @@ var cookedTokens = []Token{ ...@@ -77,17 +76,17 @@ var cookedTokens = []Token{
CharData([]byte("World <>'\" 白鵬翔")), CharData([]byte("World <>'\" 白鵬翔")),
EndElement{Name{"ns2", "hello"}}, EndElement{Name{"ns2", "hello"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns2", "goodbye"}, nil}, StartElement{Name{"ns2", "goodbye"}, []Attr{}},
EndElement{Name{"ns2", "goodbye"}}, EndElement{Name{"ns2", "goodbye"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns2", "inner"}, nil}, StartElement{Name{"ns2", "inner"}, []Attr{}},
EndElement{Name{"ns2", "inner"}}, EndElement{Name{"ns2", "inner"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
EndElement{Name{"ns2", "outer"}}, EndElement{Name{"ns2", "outer"}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
StartElement{Name{"ns3", "name"}, nil}, StartElement{Name{"ns3", "name"}, []Attr{}},
CharData([]byte("\n ")), CharData([]byte("\n ")),
CharData([]byte("Some text here.")), CharData([]byte("Some text here.")),
CharData([]byte("\n ")), CharData([]byte("\n ")),
...@@ -105,7 +104,7 @@ var rawTokensAltEncoding = []Token{ ...@@ -105,7 +104,7 @@ var rawTokensAltEncoding = []Token{
CharData([]byte("\n")), CharData([]byte("\n")),
ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)}, ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)},
CharData([]byte("\n")), CharData([]byte("\n")),
StartElement{Name{"", "tag"}, nil}, StartElement{Name{"", "tag"}, []Attr{}},
CharData([]byte("value")), CharData([]byte("value")),
EndElement{Name{"", "tag"}}, EndElement{Name{"", "tag"}},
} }
...@@ -205,7 +204,7 @@ func (d *downCaser) ReadByte() (c byte, err error) { ...@@ -205,7 +204,7 @@ func (d *downCaser) ReadByte() (c byte, err error) {
func (d *downCaser) Read(p []byte) (int, error) { func (d *downCaser) Read(p []byte) (int, error) {
d.t.Fatalf("unexpected Read call on downCaser reader") d.t.Fatalf("unexpected Read call on downCaser reader")
return 0, os.EINVAL panic("unreachable")
} }
func TestRawTokenAltEncoding(t *testing.T) { func TestRawTokenAltEncoding(t *testing.T) {
......
...@@ -105,9 +105,9 @@ func (w *Watcher) AddWatch(path string, flags uint32) error { ...@@ -105,9 +105,9 @@ func (w *Watcher) AddWatch(path string, flags uint32) error {
watchEntry.flags |= flags watchEntry.flags |= flags
flags |= syscall.IN_MASK_ADD flags |= syscall.IN_MASK_ADD
} }
wd, errno := syscall.InotifyAddWatch(w.fd, path, flags) wd, err := syscall.InotifyAddWatch(w.fd, path, flags)
if wd == -1 { if err != nil {
return &os.PathError{"inotify_add_watch", path, os.Errno(errno)} return &os.PathError{"inotify_add_watch", path, err}
} }
if !found { if !found {
...@@ -139,14 +139,10 @@ func (w *Watcher) RemoveWatch(path string) error { ...@@ -139,14 +139,10 @@ func (w *Watcher) RemoveWatch(path string) error {
// readEvents reads from the inotify file descriptor, converts the // readEvents reads from the inotify file descriptor, converts the
// received events into Event objects and sends them via the Event channel // received events into Event objects and sends them via the Event channel
func (w *Watcher) readEvents() { func (w *Watcher) readEvents() {
var ( var buf [syscall.SizeofInotifyEvent * 4096]byte
buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events
n int // Number of bytes read with read()
errno int // Syscall errno
)
for { for {
n, errno = syscall.Read(w.fd, buf[0:]) n, err := syscall.Read(w.fd, buf[0:])
// See if there is a message on the "done" channel // See if there is a message on the "done" channel
var done bool var done bool
select { select {
...@@ -156,16 +152,16 @@ func (w *Watcher) readEvents() { ...@@ -156,16 +152,16 @@ func (w *Watcher) readEvents() {
// If EOF or a "done" message is received // If EOF or a "done" message is received
if n == 0 || done { if n == 0 || done {
errno := syscall.Close(w.fd) err := syscall.Close(w.fd)
if errno == -1 { if err != nil {
w.Error <- os.NewSyscallError("close", errno) w.Error <- os.NewSyscallError("close", err)
} }
close(w.Event) close(w.Event)
close(w.Error) close(w.Error)
return return
} }
if n < 0 { if n < 0 {
w.Error <- os.NewSyscallError("read", errno) w.Error <- os.NewSyscallError("read", err)
continue continue
} }
if n < syscall.SizeofInotifyEvent { if n < syscall.SizeofInotifyEvent {
......
...@@ -14,6 +14,21 @@ import ( ...@@ -14,6 +14,21 @@ import (
"strconv" "strconv"
) )
// subsetTypeArgs takes a slice of arguments from callers of the sql
// package and converts them into a slice of the driver package's
// "subset types".
func subsetTypeArgs(args []interface{}) ([]interface{}, error) {
out := make([]interface{}, len(args))
for n, arg := range args {
var err error
out[n], err = driver.DefaultParameterConverter.ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n+1, err)
}
}
return out, nil
}
// convertAssign copies to dest the value in src, converting it if possible. // convertAssign copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information. // An error is returned if the copy would result in loss of information.
// dest should be a pointer type. // dest should be a pointer type.
......
...@@ -36,19 +36,22 @@ type Driver interface { ...@@ -36,19 +36,22 @@ type Driver interface {
Open(name string) (Conn, error) Open(name string) (Conn, error)
} }
// Execer is an optional interface that may be implemented by a Driver // ErrSkip may be returned by some optional interfaces' methods to
// or a Conn. // indicate at runtime that the fast path is unavailable and the sql
// // package should continue as if the optional interface was not
// If a Driver does not implement Execer, the sql package's DB.Exec // implemented. ErrSkip is only supported where explicitly
// method first obtains a free connection from its free pool or from // documented.
// the driver's Open method. Execer should only be implemented by var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
// drivers that can provide a more efficient implementation.
// Execer is an optional interface that may be implemented by a Conn.
// //
// If a Conn does not implement Execer, the db package's DB.Exec will // If a Conn does not implement Execer, the db package's DB.Exec will
// first prepare a query, execute the statement, and then close the // first prepare a query, execute the statement, and then close the
// statement. // statement.
// //
// All arguments are of a subset type as defined in the package docs. // All arguments are of a subset type as defined in the package docs.
//
// Exec may return ErrSkip.
type Execer interface { type Execer interface {
Exec(query string, args []interface{}) (Result, error) Exec(query string, args []interface{}) (Result, error)
} }
...@@ -94,6 +97,9 @@ type Stmt interface { ...@@ -94,6 +97,9 @@ type Stmt interface {
Close() error Close() error
// NumInput returns the number of placeholder parameters. // NumInput returns the number of placeholder parameters.
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
NumInput() int NumInput() int
// Exec executes a query that doesn't return rows, such // Exec executes a query that doesn't return rows, such
...@@ -135,6 +141,8 @@ type Rows interface { ...@@ -135,6 +141,8 @@ type Rows interface {
// The dest slice may be populated with only with values // The dest slice may be populated with only with values
// of subset types defined above, but excluding string. // of subset types defined above, but excluding string.
// All string values must be converted to []byte. // All string values must be converted to []byte.
//
// Next should return io.EOF when there are no more rows.
Next(dest []interface{}) error Next(dest []interface{}) error
} }
......
...@@ -195,6 +195,29 @@ func (c *fakeConn) Close() error { ...@@ -195,6 +195,29 @@ func (c *fakeConn) Close() error {
return nil return nil
} }
func checkSubsetTypes(args []interface{}) error {
for n, arg := range args {
switch arg.(type) {
case int64, float64, bool, nil, []byte, string:
default:
return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
}
}
return nil
}
func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
// This is an optional interface, but it's implemented here
// just to check that all the args of of the proper types.
// ErrSkip is returned so the caller acts as if we didn't
// implement this at all.
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
return nil, driver.ErrSkip
}
func errf(msg string, args ...interface{}) error { func errf(msg string, args ...interface{}) error {
return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
} }
...@@ -323,6 +346,11 @@ func (s *fakeStmt) Close() error { ...@@ -323,6 +346,11 @@ func (s *fakeStmt) Close() error {
} }
func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) { func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
db := s.c.db db := s.c.db
switch s.cmd { switch s.cmd {
case "WIPE": case "WIPE":
...@@ -377,6 +405,11 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) { ...@@ -377,6 +405,11 @@ func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
} }
func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) { func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
err := checkSubsetTypes(args)
if err != nil {
return nil, err
}
db := s.c.db db := s.c.db
if len(args) != s.placeholders { if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct") panic("error in pkg db; should only get here if size is correct")
......
...@@ -88,8 +88,9 @@ type DB struct { ...@@ -88,8 +88,9 @@ type DB struct {
driver driver.Driver driver driver.Driver
dsn string dsn string
mu sync.Mutex mu sync.Mutex // protects freeConn and closed
freeConn []driver.Conn freeConn []driver.Conn
closed bool
} }
// Open opens a database specified by its database driver name and a // Open opens a database specified by its database driver name and a
...@@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) { ...@@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) {
return &DB{driver: driver, dsn: dataSourceName}, nil return &DB{driver: driver, dsn: dataSourceName}, nil
} }
// Close closes the database, releasing any open resources.
func (db *DB) Close() error {
db.mu.Lock()
defer db.mu.Unlock()
var err error
for _, c := range db.freeConn {
err1 := c.Close()
if err1 != nil {
err = err1
}
}
db.freeConn = nil
db.closed = true
return err
}
func (db *DB) maxIdleConns() int { func (db *DB) maxIdleConns() int {
const defaultMaxIdleConns = 2 const defaultMaxIdleConns = 2
// TODO(bradfitz): ask driver, if supported, for its default preference // TODO(bradfitz): ask driver, if supported, for its default preference
...@@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int { ...@@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int {
// conn returns a newly-opened or cached driver.Conn // conn returns a newly-opened or cached driver.Conn
func (db *DB) conn() (driver.Conn, error) { func (db *DB) conn() (driver.Conn, error) {
db.mu.Lock() db.mu.Lock()
if db.closed {
return nil, errors.New("sql: database is closed")
}
if n := len(db.freeConn); n > 0 { if n := len(db.freeConn); n > 0 {
conn := db.freeConn[n-1] conn := db.freeConn[n-1]
db.freeConn = db.freeConn[:n-1] db.freeConn = db.freeConn[:n-1]
...@@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { ...@@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
} }
func (db *DB) putConn(c driver.Conn) { func (db *DB) putConn(c driver.Conn) {
if n := len(db.freeConn); n < db.maxIdleConns() { db.mu.Lock()
defer db.mu.Unlock()
if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() {
db.freeConn = append(db.freeConn, c) db.freeConn = append(db.freeConn, c)
return return
} }
db.closeConn(c) db.closeConn(c) // TODO(bradfitz): release lock before calling this?
} }
func (db *DB) closeConn(c driver.Conn) { func (db *DB) closeConn(c driver.Conn) {
...@@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) { ...@@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
// Exec executes a query without returning any rows. // Exec executes a query without returning any rows.
func (db *DB) Exec(query string, args ...interface{}) (Result, error) { func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
// Optional fast path, if the driver implements driver.Execer. sargs, err := subsetTypeArgs(args)
if execer, ok := db.driver.(driver.Execer); ok { if err != nil {
resi, err := execer.Exec(query, args) return nil, err
if err != nil {
return nil, err
}
return result{resi}, nil
} }
// If the driver does not implement driver.Execer, we need
// a connection.
ci, err := db.conn() ci, err := db.conn()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -198,11 +214,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { ...@@ -198,11 +214,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
defer db.putConn(ci) defer db.putConn(ci)
if execer, ok := ci.(driver.Execer); ok { if execer, ok := ci.(driver.Execer); ok {
resi, err := execer.Exec(query, args) resi, err := execer.Exec(query, sargs)
if err != nil { if err != driver.ErrSkip {
return nil, err if err != nil {
return nil, err
}
return result{resi}, nil
} }
return result{resi}, nil
} }
sti, err := ci.Prepare(query) sti, err := ci.Prepare(query)
...@@ -210,7 +228,8 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { ...@@ -210,7 +228,8 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
return nil, err return nil, err
} }
defer sti.Close() defer sti.Close()
resi, err := sti.Exec(args)
resi, err := sti.Exec(sargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { ...@@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
return nil, err return nil, err
} }
defer sti.Close() defer sti.Close()
resi, err := sti.Exec(args)
sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
resi, err := sti.Exec(sargs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -449,7 +474,10 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { ...@@ -449,7 +474,10 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
} }
defer releaseConn() defer releaseConn()
if want := si.NumInput(); len(args) != want { // -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args)) return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args))
} }
...@@ -545,10 +573,18 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -545,10 +573,18 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(args) != si.NumInput() {
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
if want := si.NumInput(); want != -1 && len(args) != want {
return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args)) return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args))
} }
rowsi, err := si.Query(args) sargs, err := subsetTypeArgs(args)
if err != nil {
return nil, err
}
rowsi, err := si.Query(sargs)
if err != nil { if err != nil {
s.db.putConn(ci) s.db.putConn(ci)
return nil, err return nil, err
......
...@@ -34,8 +34,16 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) { ...@@ -34,8 +34,16 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) {
} }
} }
func closeDB(t *testing.T, db *DB) {
err := db.Close()
if err != nil {
t.Fatalf("error closing DB: %v", err)
}
}
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db)
var name string var name string
var age int var age int
...@@ -69,6 +77,7 @@ func TestQuery(t *testing.T) { ...@@ -69,6 +77,7 @@ func TestQuery(t *testing.T) {
func TestStatementQueryRow(t *testing.T) { func TestStatementQueryRow(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db)
stmt, err := db.Prepare("SELECT|people|age|name=?") stmt, err := db.Prepare("SELECT|people|age|name=?")
if err != nil { if err != nil {
t.Fatalf("Prepare: %v", err) t.Fatalf("Prepare: %v", err)
...@@ -94,6 +103,7 @@ func TestStatementQueryRow(t *testing.T) { ...@@ -94,6 +103,7 @@ func TestStatementQueryRow(t *testing.T) {
// just a test of fakedb itself // just a test of fakedb itself
func TestBogusPreboundParameters(t *testing.T) { func TestBogusPreboundParameters(t *testing.T) {
db := newTestDB(t, "foo") db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
_, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion") _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
if err == nil { if err == nil {
...@@ -106,6 +116,7 @@ func TestBogusPreboundParameters(t *testing.T) { ...@@ -106,6 +116,7 @@ func TestBogusPreboundParameters(t *testing.T) {
func TestDb(t *testing.T) { func TestDb(t *testing.T) {
db := newTestDB(t, "foo") db := newTestDB(t, "foo")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
stmt, err := db.Prepare("INSERT|t1|name=?,age=?") stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil { if err != nil {
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"crypto/aes"
"crypto/cipher"
"crypto/rc4"
)
// streamDump is used to dump the initial keystream for stream ciphers. It is a
// a write-only buffer, and not intended for reading so do not require a mutex.
var streamDump [512]byte
// noneCipher implements cipher.Stream and provides no encryption. It is used
// by the transport before the first key-exchange.
type noneCipher struct{}
func (c noneCipher) XORKeyStream(dst, src []byte) {
copy(dst, src)
}
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewCTR(c, iv), nil
}
func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type cipherMode struct {
keySize int
ivSize int
skip int
createFn func(key, iv []byte) (cipher.Stream, error)
}
func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
}
// Specifies a default set of ciphers and a preference order. This is based on
// OpenSSH's default client preference order, minus algorithms that are not
// implemented.
var DefaultCipherOrder = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"arcfour256", "arcfour128",
}
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": &cipherMode{16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": &cipherMode{24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": &cipherMode{32, aes.BlockSize, 0, newAESCTR},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": &cipherMode{16, 0, 1536, newRC4},
"arcfour256": &cipherMode{32, 0, 1536, newRC4},
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"testing"
)
// TestCipherReversal tests that each cipher factory produces ciphers that can
// encrypt and decrypt some data successfully.
func TestCipherReversal(t *testing.T) {
testData := []byte("abcdefghijklmnopqrstuvwxyz012345")
testKey := []byte("AbCdEfGhIjKlMnOpQrStUvWxYz012345")
testIv := []byte("sdflkjhsadflkjhasdflkjhsadfklhsa")
cryptBuffer := make([]byte, 32)
for name, cipherMode := range cipherModes {
encrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create encrypter for %q: %s", name, err)
continue
}
decrypter, err := cipherMode.createCipher(testKey, testIv)
if err != nil {
t.Errorf("failed to create decrypter for %q: %s", name, err)
continue
}
copy(cryptBuffer, testData)
encrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if name == "none" {
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made change with 'none' cipher")
continue
}
} else {
if bytes.Equal(cryptBuffer, testData) {
t.Errorf("encryption made no change with %q", name)
continue
}
}
decrypter.XORKeyStream(cryptBuffer, cryptBuffer)
if !bytes.Equal(cryptBuffer, testData) {
t.Errorf("decrypted bytes not equal to input with %q", name)
continue
}
}
}
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range DefaultCipherOrder {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
}
}
}
...@@ -35,10 +35,6 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) { ...@@ -35,10 +35,6 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
conn.Close() conn.Close()
return nil, err return nil, err
} }
if err := conn.authenticate(); err != nil {
conn.Close()
return nil, err
}
go conn.mainLoop() go conn.mainLoop()
return conn, nil return conn, nil
} }
...@@ -64,8 +60,8 @@ func (c *ClientConn) handshake() error { ...@@ -64,8 +60,8 @@ func (c *ClientConn) handshake() error {
clientKexInit := kexInitMsg{ clientKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos, KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos, ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers, CiphersClientServer: c.config.Crypto.ciphers(),
CiphersServerClient: supportedCiphers, CiphersServerClient: c.config.Crypto.ciphers(),
MACsClientServer: supportedMACs, MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs, MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions, CompressionClientServer: supportedCompressions,
...@@ -128,7 +124,10 @@ func (c *ClientConn) handshake() error { ...@@ -128,7 +124,10 @@ func (c *ClientConn) handshake() error {
if packet[0] != msgNewKeys { if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]} return UnexpectedMessageError{msgNewKeys, packet[0]}
} }
return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc) if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
return err
}
return c.authenticate(H)
} }
// kexDH performs Diffie-Hellman key agreement on a ClientConn. The // kexDH performs Diffie-Hellman key agreement on a ClientConn. The
...@@ -195,6 +194,7 @@ func (c *ClientConn) openChan(typ string) (*clientChan, error) { ...@@ -195,6 +194,7 @@ func (c *ClientConn) openChan(typ string) (*clientChan, error) {
switch msg := (<-ch.msg).(type) { switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg: case *channelOpenConfirmMsg:
ch.peersId = msg.MyId ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg: case *channelOpenFailureMsg:
c.chanlist.remove(ch.id) c.chanlist.remove(ch.id)
return nil, errors.New(msg.Message) return nil, errors.New(msg.Message)
...@@ -301,6 +301,9 @@ type ClientConfig struct { ...@@ -301,6 +301,9 @@ type ClientConfig struct {
// A slice of ClientAuth methods. Only the first instance // A slice of ClientAuth methods. Only the first instance
// of a particular RFC 4252 method will be used during authentication. // of a particular RFC 4252 method will be used during authentication.
Auth []ClientAuth Auth []ClientAuth
// Cryptographic-related configuration.
Crypto CryptoConfig
} }
func (c *ClientConfig) rand() io.Reader { func (c *ClientConfig) rand() io.Reader {
......
...@@ -6,10 +6,11 @@ package ssh ...@@ -6,10 +6,11 @@ package ssh
import ( import (
"errors" "errors"
"io"
) )
// authenticate authenticates with the remote server. See RFC 4252. // authenticate authenticates with the remote server. See RFC 4252.
func (c *ClientConn) authenticate() error { func (c *ClientConn) authenticate(session []byte) error {
// initiate user auth session // initiate user auth session
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil { if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err return err
...@@ -26,7 +27,7 @@ func (c *ClientConn) authenticate() error { ...@@ -26,7 +27,7 @@ func (c *ClientConn) authenticate() error {
// then any untried methods suggested by the server. // then any untried methods suggested by the server.
tried, remain := make(map[string]bool), make(map[string]bool) tried, remain := make(map[string]bool), make(map[string]bool)
for auth := ClientAuth(new(noneAuth)); auth != nil; { for auth := ClientAuth(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(c.config.User, c.transport) ok, methods, err := auth.auth(session, c.config.User, c.transport, c.config.rand())
if err != nil { if err != nil {
return err return err
} }
...@@ -60,7 +61,7 @@ type ClientAuth interface { ...@@ -60,7 +61,7 @@ type ClientAuth interface {
// Returns true if authentication is successful. // Returns true if authentication is successful.
// If authentication is not successful, a []string of alternative // If authentication is not successful, a []string of alternative
// method names is returned. // method names is returned.
auth(user string, t *transport) (bool, []string, error) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error)
// method returns the RFC 4252 method name. // method returns the RFC 4252 method name.
method() string method() string
...@@ -69,7 +70,7 @@ type ClientAuth interface { ...@@ -69,7 +70,7 @@ type ClientAuth interface {
// "none" authentication, RFC 4252 section 5.2. // "none" authentication, RFC 4252 section 5.2.
type noneAuth int type noneAuth int
func (n *noneAuth) auth(user string, t *transport) (bool, []string, error) { func (n *noneAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{ if err := t.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
User: user, User: user,
Service: serviceSSH, Service: serviceSSH,
...@@ -102,7 +103,7 @@ type passwordAuth struct { ...@@ -102,7 +103,7 @@ type passwordAuth struct {
ClientPassword ClientPassword
} }
func (p *passwordAuth) auth(user string, t *transport) (bool, []string, error) { func (p *passwordAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
type passwordAuthMsg struct { type passwordAuthMsg struct {
User string User string
Service string Service string
...@@ -155,3 +156,140 @@ type ClientPassword interface { ...@@ -155,3 +156,140 @@ type ClientPassword interface {
func ClientAuthPassword(impl ClientPassword) ClientAuth { func ClientAuthPassword(impl ClientPassword) ClientAuth {
return &passwordAuth{impl} return &passwordAuth{impl}
} }
// ClientKeyring implements access to a client key ring.
type ClientKeyring interface {
// Key returns the i'th rsa.Publickey or dsa.Publickey, or nil if
// no key exists at i.
Key(i int) (key interface{}, err error)
// Sign returns a signature of the given data using the i'th key
// and the supplied random source.
Sign(i int, rand io.Reader, data []byte) (sig []byte, err error)
}
// "publickey" authentication, RFC 4252 Section 7.
type publickeyAuth struct {
ClientKeyring
}
func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.Reader) (bool, []string, error) {
type publickeyAuthMsg struct {
User string
Service string
Method string
// HasSig indicates to the reciver packet that the auth request is signed and
// should be used for authentication of the request.
HasSig bool
Algoname string
Pubkey string
// Sig is defined as []byte so marshal will exclude it during the query phase
Sig []byte `ssh:"rest"`
}
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
var index int
// a map of public keys to their index in the keyring
validKeys := make(map[int]interface{})
for {
key, err := p.Key(index)
if err != nil {
return false, nil, err
}
if key == nil {
// no more keys in the keyring
break
}
pubkey := serializePublickey(key)
algoname := algoName(key)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: false,
Algoname: algoname,
Pubkey: string(pubkey),
}
if err := t.writePacket(marshal(msgUserAuthRequest, msg)); err != nil {
return false, nil, err
}
packet, err := t.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthPubKeyOk:
msg := decode(packet).(*userAuthPubKeyOkMsg)
if msg.Algo != algoname || msg.PubKey != string(pubkey) {
continue
}
validKeys[index] = key
case msgUserAuthFailure:
default:
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
index++
}
// methods that may continue if this auth is not successful.
var methods []string
for i, key := range validKeys {
pubkey := serializePublickey(key)
algoname := algoName(key)
sign, err := p.Sign(i, rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
}, []byte(algoname), pubkey))
if err != nil {
return false, nil, err
}
// manually wrap the serialized signature in a string
s := serializeSignature(algoname, sign)
sig := make([]byte, stringLength(s))
marshalString(sig, s)
msg := publickeyAuthMsg{
User: user,
Service: serviceSSH,
Method: p.method(),
HasSig: true,
Algoname: algoname,
Pubkey: string(pubkey),
Sig: sig,
}
p := marshal(msgUserAuthRequest, msg)
if err := t.writePacket(p); err != nil {
return false, nil, err
}
packet, err := t.readPacket()
if err != nil {
return false, nil, err
}
switch packet[0] {
case msgUserAuthSuccess:
return true, nil, nil
case msgUserAuthFailure:
msg := decode(packet).(*userAuthFailureMsg)
methods = msg.Methods
continue
case msgDisconnect:
return false, nil, io.EOF
default:
return false, nil, UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
}
return false, methods, nil
}
func (p *publickeyAuth) method() string {
return "publickey"
}
// ClientAuthPublickey returns a ClientAuth using public key authentication.
func ClientAuthPublickey(impl ClientKeyring) ClientAuth {
return &publickeyAuth{impl}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"io/ioutil"
"testing"
)
const _pem = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA19lGVsTqIT5iiNYRgnoY1CwkbETW5cq+Rzk5v/kTlf31XpSU
70HVWkbTERECjaYdXM2gGcbb+sxpq6GtXf1M3kVomycqhxwhPv4Cr6Xp4WT/jkFx
9z+FFzpeodGJWjOH6L2H5uX1Cvr9EDdQp9t9/J32/qBFntY8GwoUI/y/1MSTmMiF
tupdMODN064vd3gyMKTwrlQ8tZM6aYuyOPsutLlUY7M5x5FwMDYvnPDSeyT/Iw0z
s3B+NCyqeeMd2T7YzQFnRATj0M7rM5LoSs7DVqVriOEABssFyLj31PboaoLhOKgc
qoM9khkNzr7FHVvi+DhYM2jD0DwvqZLN6NmnLwIDAQABAoIBAQCGVj+kuSFOV1lT
+IclQYA6bM6uY5mroqcSBNegVxCNhWU03BxlW//BE9tA/+kq53vWylMeN9mpGZea
riEMIh25KFGWXqXlOOioH8bkMsqA8S7sBmc7jljyv+0toQ9vCCtJ+sueNPhxQQxH
D2YvUjfzBQ04I9+wn30BByDJ1QA/FoPsunxIOUCcRBE/7jxuLYcpR+JvEF68yYIh
atXRld4W4in7T65YDR8jK1Uj9XAcNeDYNpT/M6oFLx1aPIlkG86aCWRO19S1jLPT
b1ZAKHHxPMCVkSYW0RqvIgLXQOR62D0Zne6/2wtzJkk5UCjkSQ2z7ZzJpMkWgDgN
ifCULFPBAoGBAPoMZ5q1w+zB+knXUD33n1J+niN6TZHJulpf2w5zsW+m2K6Zn62M
MXndXlVAHtk6p02q9kxHdgov34Uo8VpuNjbS1+abGFTI8NZgFo+bsDxJdItemwC4
KJ7L1iz39hRN/ZylMRLz5uTYRGddCkeIHhiG2h7zohH/MaYzUacXEEy3AoGBANz8
e/msleB+iXC0cXKwds26N4hyMdAFE5qAqJXvV3S2W8JZnmU+sS7vPAWMYPlERPk1
D8Q2eXqdPIkAWBhrx4RxD7rNc5qFNcQWEhCIxC9fccluH1y5g2M+4jpMX2CT8Uv+
3z+NoJ5uDTXZTnLCfoZzgZ4nCZVZ+6iU5U1+YXFJAoGBANLPpIV920n/nJmmquMj
orI1R/QXR9Cy56cMC65agezlGOfTYxk5Cfl5Ve+/2IJCfgzwJyjWUsFx7RviEeGw
64o7JoUom1HX+5xxdHPsyZ96OoTJ5RqtKKoApnhRMamau0fWydH1yeOEJd+TRHhc
XStGfhz8QNa1dVFvENczja1vAoGABGWhsd4VPVpHMc7lUvrf4kgKQtTC2PjA4xoc
QJ96hf/642sVE76jl+N6tkGMzGjnVm4P2j+bOy1VvwQavKGoXqJBRd5Apppv727g
/SM7hBXKFc/zH80xKBBgP/i1DR7kdjakCoeu4ngeGywvu2jTS6mQsqzkK+yWbUxJ
I7mYBsECgYB/KNXlTEpXtz/kwWCHFSYA8U74l7zZbVD8ul0e56JDK+lLcJ0tJffk
gqnBycHj6AhEycjda75cs+0zybZvN4x65KZHOGW/O/7OAWEcZP5TPb3zf9ned3Hl
NsZoFj52ponUM6+99A2CmezFCN16c4mbA//luWF+k3VVqR6BpkrhKw==
-----END RSA PRIVATE KEY-----`
// reused internally by tests
var serverConfig = new(ServerConfig)
func init() {
if err := serverConfig.SetRSAPrivateKey([]byte(_pem)); err != nil {
panic("unable to set private key: " + err.Error())
}
}
// keychain implements the ClientPublickey interface
type keychain struct {
keys []*rsa.PrivateKey
}
func (k *keychain) Key(i int) (interface{}, error) {
if i < 0 || i >= len(k.keys) {
return nil, nil
}
return k.keys[i].PublicKey, nil
}
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
hashFunc := crypto.SHA1
h := hashFunc.New()
h.Write(data)
digest := h.Sum()
return rsa.SignPKCS1v15(rand, k.keys[i], hashFunc, digest)
}
func (k *keychain) loadPEM(file string) error {
buf, err := ioutil.ReadFile(file)
if err != nil {
return err
}
block, _ := pem.Decode(buf)
if block == nil {
return errors.New("ssh: no key found")
}
r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
}
k.keys = append(k.keys, r)
return nil
}
var pkey *rsa.PrivateKey
func init() {
var err error
pkey, err = rsa.GenerateKey(rand.Reader, 512)
if err != nil {
panic("unable to generate public key")
}
}
func TestClientAuthPublickey(t *testing.T) {
k := new(keychain)
k.keys = append(k.keys, pkey)
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
expected := []byte(serializePublickey(k.keys[0].PublicKey))
algoname := algoName(k.keys[0].PublicKey)
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
}
serverConfig.PasswordCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool, 1)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer c.Close()
if err := c.Handshake(); err != nil {
t.Error(err)
}
done <- true
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPublickey(k),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}
// password implements the ClientPassword interface
type password string
func (p password) Password(user string) (string, error) {
return string(p), nil
}
func TestClientAuthPassword(t *testing.T) {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
serverConfig.PubKeyCallback = nil
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
if err := c.Handshake(); err != nil {
t.Error(err)
}
defer c.Close()
done <- true
}()
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(pw),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}
func TestClientAuthPasswordAndPublickey(t *testing.T) {
pw := password("tiger")
serverConfig.PasswordCallback = func(user, pass string) bool {
return user == "testuser" && pass == string(pw)
}
k := new(keychain)
k.keys = append(k.keys, pkey)
serverConfig.PubKeyCallback = func(user, algo string, pubkey []byte) bool {
expected := []byte(serializePublickey(k.keys[0].PublicKey))
algoname := algoName(k.keys[0].PublicKey)
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
}
l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
if err != nil {
t.Fatalf("unable to listen: %s", err)
}
defer l.Close()
done := make(chan bool)
go func() {
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
if err := c.Handshake(); err != nil {
t.Error(err)
}
defer c.Close()
done <- true
}()
wrongPw := password("wrong")
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
ClientAuthPassword(wrongPw),
ClientAuthPublickey(k),
},
}
c, err := Dial("tcp", l.Addr().String(), config)
if err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
defer c.Close()
<-done
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// ClientConn functional tests.
// These tests require a running ssh server listening on port 22
// on the local host. Functional tests will be skipped unless
// -ssh.user and -ssh.pass must be passed to gotest.
import (
"flag"
"testing"
)
var (
sshuser = flag.String("ssh.user", "", "ssh username")
sshpass = flag.String("ssh.pass", "", "ssh password")
sshprivkey = flag.String("ssh.privkey", "", "ssh privkey file")
)
func TestFuncPasswordAuth(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
config := &ClientConfig{
User: *sshuser,
Auth: []ClientAuth{
ClientAuthPassword(password(*sshpass)),
},
}
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
t.Fatalf("Unable to connect: %s", err)
}
defer conn.Close()
}
func TestFuncPublickeyAuth(t *testing.T) {
if *sshuser == "" {
t.Log("ssh.user not defined, skipping test")
return
}
kc := new(keychain)
if err := kc.loadPEM(*sshprivkey); err != nil {
t.Fatalf("unable to load private key: %s", err)
}
config := &ClientConfig{
User: *sshuser,
Auth: []ClientAuth{
ClientAuthPublickey(kc),
},
}
conn, err := Dial("tcp", "localhost:22", config)
if err != nil {
t.Fatalf("unable to connect: %s", err)
}
defer conn.Close()
}
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
package ssh package ssh
import ( import (
"crypto/dsa"
"crypto/rsa"
"math/big" "math/big"
"strconv" "strconv"
"sync" "sync"
...@@ -14,7 +16,6 @@ import ( ...@@ -14,7 +16,6 @@ import (
const ( const (
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
hostAlgoRSA = "ssh-rsa" hostAlgoRSA = "ssh-rsa"
cipherAES128CTR = "aes128-ctr"
macSHA196 = "hmac-sha1-96" macSHA196 = "hmac-sha1-96"
compressionNone = "none" compressionNone = "none"
serviceUserAuth = "ssh-userauth" serviceUserAuth = "ssh-userauth"
...@@ -23,7 +24,6 @@ const ( ...@@ -23,7 +24,6 @@ const (
var supportedKexAlgos = []string{kexAlgoDH14SHA1} var supportedKexAlgos = []string{kexAlgoDH14SHA1}
var supportedHostKeyAlgos = []string{hostAlgoRSA} var supportedHostKeyAlgos = []string{hostAlgoRSA}
var supportedCiphers = []string{cipherAES128CTR}
var supportedMACs = []string{macSHA196} var supportedMACs = []string{macSHA196}
var supportedCompressions = []string{compressionNone} var supportedCompressions = []string{compressionNone}
...@@ -127,3 +127,100 @@ func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *ke ...@@ -127,3 +127,100 @@ func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *ke
ok = true ok = true
return return
} }
// Cryptographic configuration common to both ServerConfig and ClientConfig.
type CryptoConfig struct {
// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
// used.
Ciphers []string
}
func (c *CryptoConfig) ciphers() []string {
if c.Ciphers == nil {
return DefaultCipherOrder
}
return c.Ciphers
}
// serialize a signed slice according to RFC 4254 6.6.
func serializeSignature(algoname string, sig []byte) []byte {
length := stringLength([]byte(algoname))
length += stringLength(sig)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalString(r, sig)
return ret
}
// serialize an rsa.PublicKey or dsa.PublicKey according to RFC 4253 6.6.
func serializePublickey(key interface{}) []byte {
algoname := algoName(key)
switch key := key.(type) {
case rsa.PublicKey:
e := new(big.Int).SetInt64(int64(key.E))
length := stringLength([]byte(algoname))
length += intLength(e)
length += intLength(key.N)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalInt(r, e)
marshalInt(r, key.N)
return ret
case dsa.PublicKey:
length := stringLength([]byte(algoname))
length += intLength(key.P)
length += intLength(key.Q)
length += intLength(key.G)
length += intLength(key.Y)
ret := make([]byte, length)
r := marshalString(ret, []byte(algoname))
r = marshalInt(r, key.P)
r = marshalInt(r, key.Q)
r = marshalInt(r, key.G)
marshalInt(r, key.Y)
return ret
}
panic("unexpected key type")
}
func algoName(key interface{}) string {
switch key.(type) {
case rsa.PublicKey:
return "ssh-rsa"
case dsa.PublicKey:
return "ssh-dss"
}
panic("unexpected key type")
}
// buildDataSignedForAuth returns the data that is signed in order to prove
// posession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(sessionId)
length += 1
length += stringLength(user)
length += stringLength(service)
length += stringLength(method)
length += 1
length += stringLength(algo)
length += stringLength(pubKey)
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}
...@@ -392,7 +392,10 @@ func parseString(in []byte) (out, rest []byte, ok bool) { ...@@ -392,7 +392,10 @@ func parseString(in []byte) (out, rest []byte, ok bool) {
return return
} }
var comma = []byte{','} var (
comma = []byte{','}
emptyNameList = []string{}
)
func parseNameList(in []byte) (out []string, rest []byte, ok bool) { func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
contents, rest, ok := parseString(in) contents, rest, ok := parseString(in)
...@@ -400,6 +403,7 @@ func parseNameList(in []byte) (out []string, rest []byte, ok bool) { ...@@ -400,6 +403,7 @@ func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
return return
} }
if len(contents) == 0 { if len(contents) == 0 {
out = emptyNameList
return return
} }
parts := bytes.Split(contents, comma) parts := bytes.Split(contents, comma)
...@@ -444,8 +448,6 @@ func parseUint32(in []byte) (out uint32, rest []byte, ok bool) { ...@@ -444,8 +448,6 @@ func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
return return
} }
const maxPacketSize = 36000
func nameListLength(namelist []string) int { func nameListLength(namelist []string) int {
length := 4 /* uint32 length prefix */ length := 4 /* uint32 length prefix */
for i, name := range namelist { for i, name := range namelist {
......
...@@ -40,6 +40,9 @@ type ServerConfig struct { ...@@ -40,6 +40,9 @@ type ServerConfig struct {
// key authentication. It must return true iff the given public key is // key authentication. It must return true iff the given public key is
// valid for the given user. // valid for the given user.
PubKeyCallback func(user, algo string, pubkey []byte) bool PubKeyCallback func(user, algo string, pubkey []byte) bool
// Cryptographic-related configuration.
Crypto CryptoConfig
} }
func (c *ServerConfig) rand() io.Reader { func (c *ServerConfig) rand() io.Reader {
...@@ -221,7 +224,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha ...@@ -221,7 +224,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return nil, nil, errors.New("internal error") return nil, nil, errors.New("internal error")
} }
serializedSig := serializeRSASignature(sig) serializedSig := serializeSignature(hostAlgoRSA, sig)
kexDHReply := kexDHReplyMsg{ kexDHReply := kexDHReplyMsg{
HostKey: serializedHostKey, HostKey: serializedHostKey,
...@@ -234,50 +237,9 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha ...@@ -234,50 +237,9 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
return return
} }
func serializeRSASignature(sig []byte) []byte {
length := stringLength([]byte(hostAlgoRSA))
length += stringLength(sig)
ret := make([]byte, length)
r := marshalString(ret, []byte(hostAlgoRSA))
r = marshalString(r, sig)
return ret
}
// serverVersion is the fixed identification string that Server will use. // serverVersion is the fixed identification string that Server will use.
var serverVersion = []byte("SSH-2.0-Go\r\n") var serverVersion = []byte("SSH-2.0-Go\r\n")
// buildDataSignedForAuth returns the data that is signed in order to prove
// posession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
user := []byte(req.User)
service := []byte(req.Service)
method := []byte(req.Method)
length := stringLength(sessionId)
length += 1
length += stringLength(user)
length += stringLength(service)
length += stringLength(method)
length += 1
length += stringLength(algo)
length += stringLength(pubKey)
ret := make([]byte, length)
r := marshalString(ret, sessionId)
r[0] = msgUserAuthRequest
r = r[1:]
r = marshalString(r, user)
r = marshalString(r, service)
r = marshalString(r, method)
r[0] = 1
r = r[1:]
r = marshalString(r, algo)
r = marshalString(r, pubKey)
return ret
}
// Handshake performs an SSH transport and client authentication on the given ServerConn. // Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error { func (s *ServerConn) Handshake() error {
var magics handshakeMagics var magics handshakeMagics
...@@ -298,8 +260,8 @@ func (s *ServerConn) Handshake() error { ...@@ -298,8 +260,8 @@ func (s *ServerConn) Handshake() error {
serverKexInit := kexInitMsg{ serverKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos, KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos, ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers, CiphersClientServer: s.config.Crypto.ciphers(),
CiphersServerClient: supportedCiphers, CiphersServerClient: s.config.Crypto.ciphers(),
MACsClientServer: supportedMACs, MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs, MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions, CompressionClientServer: supportedCompressions,
...@@ -364,7 +326,9 @@ func (s *ServerConn) Handshake() error { ...@@ -364,7 +326,9 @@ func (s *ServerConn) Handshake() error {
if packet[0] != msgNewKeys { if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]} return UnexpectedMessageError{msgNewKeys, packet[0]}
} }
s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
return err
}
if packet, err = s.readPacket(); err != nil { if packet, err = s.readPacket(); err != nil {
return err return err
} }
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"errors"
"io"
"net"
)
// Dial initiates a connection to the addr from the remote host.
// addr is resolved using net.ResolveTCPAddr before connection.
// This could allow an observer to observe the DNS name of the
// remote host. Consider using ssh.DialTCP to avoid this.
func (c *ClientConn) Dial(n, addr string) (net.Conn, error) {
raddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.DialTCP(n, nil, raddr)
}
// DialTCP connects to the remote address raddr on the network net,
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
// as the local address for the connection.
func (c *ClientConn) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
if laddr == nil {
laddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
if err != nil {
return nil, err
}
return &tcpchanconn{
tcpchan: ch,
laddr: laddr,
raddr: raddr,
}, nil
}
// dial opens a direct-tcpip connection to the remote server. laddr and raddr are passed as
// strings and are expected to be resolveable at the remote end.
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpchan, error) {
// RFC 4254 7.2
type channelOpenDirectMsg struct {
ChanType string
PeersId uint32
PeersWindow uint32
MaxPacketSize uint32
raddr string
rport uint32
laddr string
lport uint32
}
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
ChanType: "direct-tcpip",
PeersId: ch.id,
PeersWindow: 1 << 14,
MaxPacketSize: 1 << 15, // RFC 4253 6.1
raddr: raddr,
rport: uint32(rport),
laddr: laddr,
lport: uint32(lport),
})); err != nil {
c.chanlist.remove(ch.id)
return nil, err
}
// wait for response
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
ch.win <- int(msg.MyWindow)
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, errors.New("ssh: error opening remote TCP connection: " + msg.Message)
default:
c.chanlist.remove(ch.id)
return nil, errors.New("ssh: unexpected packet")
}
return &tcpchan{
clientChan: ch,
Reader: &chanReader{
packetWriter: ch,
id: ch.id,
data: ch.data,
},
Writer: &chanWriter{
packetWriter: ch,
id: ch.id,
win: ch.win,
},
}, nil
}
type tcpchan struct {
*clientChan // the backing channel
io.Reader
io.Writer
}
// tcpchanconn fulfills the net.Conn interface without
// the tcpchan having to hold laddr or raddr directly.
type tcpchanconn struct {
*tcpchan
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
func (t *tcpchanconn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
func (t *tcpchanconn) RemoteAddr() net.Addr {
return t.raddr
}
// SetTimeout sets the read and write deadlines associated
// with the connection.
func (t *tcpchanconn) SetTimeout(nsec int64) error {
if err := t.SetReadTimeout(nsec); err != nil {
return err
}
return t.SetWriteTimeout(nsec)
}
// SetReadTimeout sets the time (in nanoseconds) that
// Read will wait for data before returning an error with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
func (t *tcpchanconn) SetReadTimeout(nsec int64) error {
return errors.New("ssh: tcpchan: timeout not supported")
}
// SetWriteTimeout sets the time (in nanoseconds) that
// Write will wait to send its data before returning an error with Timeout() == true.
// Setting nsec == 0 (the default) disables the deadline.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
func (t *tcpchanconn) SetWriteTimeout(nsec int64) error {
return errors.New("ssh: tcpchan: timeout not supported")
}
...@@ -7,7 +7,6 @@ package ssh ...@@ -7,7 +7,6 @@ package ssh
import ( import (
"bufio" "bufio"
"crypto" "crypto"
"crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/hmac" "crypto/hmac"
"crypto/subtle" "crypto/subtle"
...@@ -19,7 +18,10 @@ import ( ...@@ -19,7 +18,10 @@ import (
) )
const ( const (
paddingMultiple = 16 // TODO(dfc) does this need to be configurable? packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
minPacketSize = 16
maxPacketSize = 36000
minPaddingSize = 4 // TODO(huin) should this be configurable?
) )
// filteredConn reduces the set of methods exposed when embeddeding // filteredConn reduces the set of methods exposed when embeddeding
...@@ -61,8 +63,7 @@ type reader struct { ...@@ -61,8 +63,7 @@ type reader struct {
type writer struct { type writer struct {
*sync.Mutex // protects writer.Writer from concurrent writes *sync.Mutex // protects writer.Writer from concurrent writes
*bufio.Writer *bufio.Writer
paddingMultiple int rand io.Reader
rand io.Reader
common common
} }
...@@ -82,14 +83,11 @@ type common struct { ...@@ -82,14 +83,11 @@ type common struct {
func (r *reader) readOnePacket() ([]byte, error) { func (r *reader) readOnePacket() ([]byte, error) {
var lengthBytes = make([]byte, 5) var lengthBytes = make([]byte, 5)
var macSize uint32 var macSize uint32
if _, err := io.ReadFull(r, lengthBytes); err != nil { if _, err := io.ReadFull(r, lengthBytes); err != nil {
return nil, err return nil, err
} }
if r.cipher != nil { r.cipher.XORKeyStream(lengthBytes, lengthBytes)
r.cipher.XORKeyStream(lengthBytes, lengthBytes)
}
if r.mac != nil { if r.mac != nil {
r.mac.Reset() r.mac.Reset()
...@@ -153,9 +151,9 @@ func (w *writer) writePacket(packet []byte) error { ...@@ -153,9 +151,9 @@ func (w *writer) writePacket(packet []byte) error {
w.Mutex.Lock() w.Mutex.Lock()
defer w.Mutex.Unlock() defer w.Mutex.Unlock()
paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple paddingLength := packetSizeMultiple - (5+len(packet))%packetSizeMultiple
if paddingLength < 4 { if paddingLength < 4 {
paddingLength += paddingMultiple paddingLength += packetSizeMultiple
} }
length := len(packet) + 1 + paddingLength length := len(packet) + 1 + paddingLength
...@@ -188,11 +186,9 @@ func (w *writer) writePacket(packet []byte) error { ...@@ -188,11 +186,9 @@ func (w *writer) writePacket(packet []byte) error {
// TODO(dfc) lengthBytes, packet and padding should be // TODO(dfc) lengthBytes, packet and padding should be
// subslices of a single buffer // subslices of a single buffer
if w.cipher != nil { w.cipher.XORKeyStream(lengthBytes, lengthBytes)
w.cipher.XORKeyStream(lengthBytes, lengthBytes) w.cipher.XORKeyStream(packet, packet)
w.cipher.XORKeyStream(packet, packet) w.cipher.XORKeyStream(padding, padding)
w.cipher.XORKeyStream(padding, padding)
}
if _, err := w.Write(lengthBytes); err != nil { if _, err := w.Write(lengthBytes); err != nil {
return err return err
...@@ -227,11 +223,17 @@ func newTransport(conn net.Conn, rand io.Reader) *transport { ...@@ -227,11 +223,17 @@ func newTransport(conn net.Conn, rand io.Reader) *transport {
return &transport{ return &transport{
reader: reader{ reader: reader{
Reader: bufio.NewReader(conn), Reader: bufio.NewReader(conn),
common: common{
cipher: noneCipher{},
},
}, },
writer: writer{ writer: writer{
Writer: bufio.NewWriter(conn), Writer: bufio.NewWriter(conn),
rand: rand, rand: rand,
Mutex: new(sync.Mutex), Mutex: new(sync.Mutex),
common: common{
cipher: noneCipher{},
},
}, },
filteredConn: conn, filteredConn: conn,
} }
...@@ -249,29 +251,32 @@ var ( ...@@ -249,29 +251,32 @@ var (
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
) )
// setupKeys sets the cipher and MAC keys from K, H and sessionId, as // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys // described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys). // (to setup server->client keys) or clientKeys (for client->server keys).
func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error { func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
h := hashFunc.New() cipherMode := cipherModes[c.cipherAlgo]
blockSize := 16
keySize := 16
macKeySize := 20 macKeySize := 20
iv := make([]byte, blockSize) iv := make([]byte, cipherMode.ivSize)
key := make([]byte, keySize) key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macKeySize) macKey := make([]byte, macKeySize)
h := hashFunc.New()
generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h) generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
generateKeyMaterial(key, d.keyTag, K, H, sessionId, h) generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h) generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)} c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
aes, err := aes.NewCipher(key)
cipher, err := cipherMode.createCipher(key, iv)
if err != nil { if err != nil {
return err return err
} }
c.cipher = cipher.NewCTR(aes, iv)
c.cipher = cipher
return nil return nil
} }
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import "io"
// Shell contains the state for running a VT100 terminal that is capable of
// reading lines of input.
type Shell struct {
c io.ReadWriter
prompt string
// line is the current line being entered.
line []byte
// pos is the logical position of the cursor in line
pos int
// cursorX contains the current X value of the cursor where the left
// edge is 0. cursorY contains the row number where the first row of
// the current line is 0.
cursorX, cursorY int
// maxLine is the greatest value of cursorY so far.
maxLine int
termWidth, termHeight int
// outBuf contains the terminal data to be sent.
outBuf []byte
// remainder contains the remainder of any partial key sequences after
// a read. It aliases into inBuf.
remainder []byte
inBuf [256]byte
}
// NewShell runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
// a local terminal, that terminal must first have been put into raw mode.
// prompt is a string that is written at the start of each input line (i.e.
// "> ").
func NewShell(c io.ReadWriter, prompt string) *Shell {
return &Shell{
c: c,
prompt: prompt,
termWidth: 80,
termHeight: 24,
}
}
const (
keyCtrlD = 4
keyEnter = '\r'
keyEscape = 27
keyBackspace = 127
keyUnknown = 256 + iota
keyUp
keyDown
keyLeft
keyRight
keyAltLeft
keyAltRight
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns -1.
func bytesToKey(b []byte) (int, []byte) {
if len(b) == 0 {
return -1, nil
}
if b[0] != keyEscape {
return int(b[0]), b[1:]
}
if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
switch b[2] {
case 'A':
return keyUp, b[3:]
case 'B':
return keyDown, b[3:]
case 'C':
return keyRight, b[3:]
case 'D':
return keyLeft, b[3:]
}
}
if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
switch b[5] {
case 'C':
return keyAltRight, b[6:]
case 'D':
return keyAltLeft, b[6:]
}
}
// If we get here then we have a key that we don't recognise, or a
// partial sequence. It's not clear how one should find the end of a
// sequence without knowing them all, but it seems that [a-zA-Z] only
// appears at the end of a sequence.
for i, c := range b[0:] {
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
return keyUnknown, b[i+1:]
}
}
return -1, b
}
// queue appends data to the end of ss.outBuf
func (ss *Shell) queue(data []byte) {
if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
copy(newOutBuf, ss.outBuf)
ss.outBuf = newOutBuf
}
oldLen := len(ss.outBuf)
ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
copy(ss.outBuf[oldLen:], data)
}
var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
func isPrintable(key int) bool {
return key >= 32 && key < 127
}
// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
// given, logical position in the text.
func (ss *Shell) moveCursorToPos(pos int) {
x := len(ss.prompt) + pos
y := x / ss.termWidth
x = x % ss.termWidth
up := 0
if y < ss.cursorY {
up = ss.cursorY - y
}
down := 0
if y > ss.cursorY {
down = y - ss.cursorY
}
left := 0
if x < ss.cursorX {
left = ss.cursorX - x
}
right := 0
if x > ss.cursorX {
right = x - ss.cursorX
}
movement := make([]byte, 3*(up+down+left+right))
m := movement
for i := 0; i < up; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'A'
m = m[3:]
}
for i := 0; i < down; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'B'
m = m[3:]
}
for i := 0; i < left; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'D'
m = m[3:]
}
for i := 0; i < right; i++ {
m[0] = keyEscape
m[1] = '['
m[2] = 'C'
m = m[3:]
}
ss.cursorX = x
ss.cursorY = y
ss.queue(movement)
}
const maxLineLength = 4096
// handleKey processes the given key and, optionally, returns a line of text
// that the user has entered.
func (ss *Shell) handleKey(key int) (line string, ok bool) {
switch key {
case keyBackspace:
if ss.pos == 0 {
return
}
ss.pos--
copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
ss.line = ss.line[:len(ss.line)-1]
ss.writeLine(ss.line[ss.pos:])
ss.moveCursorToPos(ss.pos)
ss.queue(eraseUnderCursor)
case keyAltLeft:
// move left by a word.
if ss.pos == 0 {
return
}
ss.pos--
for ss.pos > 0 {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos--
}
for ss.pos > 0 {
if ss.line[ss.pos] == ' ' {
ss.pos++
break
}
ss.pos--
}
ss.moveCursorToPos(ss.pos)
case keyAltRight:
// move right by a word.
for ss.pos < len(ss.line) {
if ss.line[ss.pos] == ' ' {
break
}
ss.pos++
}
for ss.pos < len(ss.line) {
if ss.line[ss.pos] != ' ' {
break
}
ss.pos++
}
ss.moveCursorToPos(ss.pos)
case keyLeft:
if ss.pos == 0 {
return
}
ss.pos--
ss.moveCursorToPos(ss.pos)
case keyRight:
if ss.pos == len(ss.line) {
return
}
ss.pos++
ss.moveCursorToPos(ss.pos)
case keyEnter:
ss.moveCursorToPos(len(ss.line))
ss.queue([]byte("\r\n"))
line = string(ss.line)
ok = true
ss.line = ss.line[:0]
ss.pos = 0
ss.cursorX = 0
ss.cursorY = 0
ss.maxLine = 0
default:
if !isPrintable(key) {
return
}
if len(ss.line) == maxLineLength {
return
}
if len(ss.line) == cap(ss.line) {
newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
copy(newLine, ss.line)
ss.line = newLine
}
ss.line = ss.line[:len(ss.line)+1]
copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
ss.line[ss.pos] = byte(key)
ss.writeLine(ss.line[ss.pos:])
ss.pos++
ss.moveCursorToPos(ss.pos)
}
return
}
func (ss *Shell) writeLine(line []byte) {
for len(line) != 0 {
if ss.cursorX == ss.termWidth {
ss.queue([]byte("\r\n"))
ss.cursorX = 0
ss.cursorY++
if ss.cursorY > ss.maxLine {
ss.maxLine = ss.cursorY
}
}
remainingOnLine := ss.termWidth - ss.cursorX
todo := len(line)
if todo > remainingOnLine {
todo = remainingOnLine
}
ss.queue(line[:todo])
ss.cursorX += todo
line = line[todo:]
}
}
func (ss *Shell) Write(buf []byte) (n int, err error) {
return ss.c.Write(buf)
}
// ReadLine returns a line of input from the terminal.
func (ss *Shell) ReadLine() (line string, err error) {
ss.writeLine([]byte(ss.prompt))
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
for {
// ss.remainder is a slice at the beginning of ss.inBuf
// containing a partial key sequence
readBuf := ss.inBuf[len(ss.remainder):]
var n int
n, err = ss.c.Read(readBuf)
if err != nil {
return
}
if err == nil {
ss.remainder = ss.inBuf[:n+len(ss.remainder)]
rest := ss.remainder
lineOk := false
for !lineOk {
var key int
key, rest = bytesToKey(rest)
if key < 0 {
break
}
if key == keyCtrlD {
return "", io.EOF
}
line, lineOk = ss.handleKey(key)
}
if len(rest) > 0 {
n := copy(ss.inBuf[:], rest)
ss.remainder = ss.inBuf[:n]
} else {
ss.remainder = nil
}
ss.c.Write(ss.outBuf)
ss.outBuf = ss.outBuf[:0]
if lineOk {
return
}
continue
}
}
panic("unreachable")
}
...@@ -41,7 +41,7 @@ func (c *MockTerminal) Write(data []byte) (n int, err error) { ...@@ -41,7 +41,7 @@ func (c *MockTerminal) Write(data []byte) (n int, err error) {
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
c := &MockTerminal{} c := &MockTerminal{}
ss := NewShell(c, "> ") ss := NewTerminal(c, "> ")
line, err := ss.ReadLine() line, err := ss.ReadLine()
if line != "" { if line != "" {
t.Errorf("Expected empty line but got: %s", line) t.Errorf("Expected empty line but got: %s", line)
...@@ -95,7 +95,7 @@ func TestKeyPresses(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestKeyPresses(t *testing.T) {
toSend: []byte(test.in), toSend: []byte(test.in),
bytesPerRead: j, bytesPerRead: j,
} }
ss := NewShell(c, "> ") ss := NewTerminal(c, "> ")
line, err := ss.ReadLine() line, err := ss.ReadLine()
if line != test.line { if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err.String())
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"io"
"syscall"
)
// State contains the state of a terminal.
type State struct {
termios syscall.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var termios syscall.Termios
err := syscall.Tcgetattr(fd, &termios)
return err == nil
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var oldState State
if err := syscall.Tcgetattr(fd, &oldState.termios); err != nil {
return nil, err
}
newState := oldState.termios
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
return nil, err
}
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
err := syscall.Tcsetattr(fd, syscall.TCSANOW, &state.termios)
return err
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios
if err := syscall.Tcgetattr(fd, &oldState); err != nil {
return nil, err
}
newState := oldState
newState.Lflag &^= syscall.ECHO
if err := syscall.Tcsetattr(fd, syscall.TCSANOW, &newState); err != nil {
return nil, err
}
defer func() {
syscall.Tcsetattr(fd, syscall.TCSANOW, &oldState)
}()
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}
...@@ -357,6 +357,10 @@ var fmttests = []struct { ...@@ -357,6 +357,10 @@ var fmttests = []struct {
{"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`}, {"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`},
{"%#v", []string{"a", "b"}, `[]string{"a", "b"}`}, {"%#v", []string{"a", "b"}, `[]string{"a", "b"}`},
{"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`}, {"%#v", SI{}, `fmt_test.SI{I:interface {}(nil)}`},
{"%#v", []int(nil), `[]int(nil)`},
{"%#v", []int{}, `[]int{}`},
{"%#v", map[int]byte(nil), `map[int] uint8(nil)`},
{"%#v", map[int]byte{}, `map[int] uint8{}`},
// slices with other formats // slices with other formats
{"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`}, {"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`},
......
...@@ -795,6 +795,10 @@ BigSwitch: ...@@ -795,6 +795,10 @@ BigSwitch:
case reflect.Map: case reflect.Map:
if goSyntax { if goSyntax {
p.buf.WriteString(f.Type().String()) p.buf.WriteString(f.Type().String())
if f.IsNil() {
p.buf.WriteString("(nil)")
break
}
p.buf.WriteByte('{') p.buf.WriteByte('{')
} else { } else {
p.buf.Write(mapBytes) p.buf.Write(mapBytes)
...@@ -873,6 +877,10 @@ BigSwitch: ...@@ -873,6 +877,10 @@ BigSwitch:
} }
if goSyntax { if goSyntax {
p.buf.WriteString(value.Type().String()) p.buf.WriteString(value.Type().String())
if f.IsNil() {
p.buf.WriteString("(nil)")
break
}
p.buf.WriteByte('{') p.buf.WriteByte('{')
} else { } else {
p.buf.WriteByte('[') p.buf.WriteByte('[')
......
...@@ -324,7 +324,7 @@ var x, y Xs ...@@ -324,7 +324,7 @@ var x, y Xs
var z IntString var z IntString
var multiTests = []ScanfMultiTest{ var multiTests = []ScanfMultiTest{
{"", "", nil, nil, ""}, {"", "", []interface{}{}, []interface{}{}, ""},
{"%d", "23", args(&i), args(23), ""}, {"%d", "23", args(&i), args(23), ""},
{"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""}, {"%2s%3s", "22333", args(&s, &t), args("22", "333"), ""},
{"%2d%3d", "44555", args(&i, &j), args(44, 555), ""}, {"%2d%3d", "44555", args(&i, &j), args(44, 555), ""},
...@@ -378,7 +378,7 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{} ...@@ -378,7 +378,7 @@ func testScan(name string, t *testing.T, scan func(r io.Reader, a ...interface{}
} }
val := v.Interface() val := v.Interface()
if !reflect.DeepEqual(val, test.out) { if !reflect.DeepEqual(val, test.out) {
t.Errorf("%s scanning %q: expected %v got %v, type %T", name, test.text, test.out, val, val) t.Errorf("%s scanning %q: expected %#v got %#v, type %T", name, test.text, test.out, val, val)
} }
} }
} }
...@@ -417,7 +417,7 @@ func TestScanf(t *testing.T) { ...@@ -417,7 +417,7 @@ func TestScanf(t *testing.T) {
} }
val := v.Interface() val := v.Interface()
if !reflect.DeepEqual(val, test.out) { if !reflect.DeepEqual(val, test.out) {
t.Errorf("scanning (%q, %q): expected %v got %v, type %T", test.format, test.text, test.out, val, val) t.Errorf("scanning (%q, %q): expected %#v got %#v, type %T", test.format, test.text, test.out, val, val)
} }
} }
} }
...@@ -520,7 +520,7 @@ func testScanfMulti(name string, t *testing.T) { ...@@ -520,7 +520,7 @@ func testScanfMulti(name string, t *testing.T) {
} }
result := resultVal.Interface() result := resultVal.Interface()
if !reflect.DeepEqual(result, test.out) { if !reflect.DeepEqual(result, test.out) {
t.Errorf("scanning (%q, %q): expected %v got %v", test.format, test.text, test.out, result) t.Errorf("scanning (%q, %q): expected %#v got %#v", test.format, test.text, test.out, result)
} }
} }
} }
......
...@@ -412,29 +412,29 @@ func (x *ChanType) End() token.Pos { return x.Value.End() } ...@@ -412,29 +412,29 @@ func (x *ChanType) End() token.Pos { return x.Value.End() }
// exprNode() ensures that only expression/type nodes can be // exprNode() ensures that only expression/type nodes can be
// assigned to an ExprNode. // assigned to an ExprNode.
// //
func (x *BadExpr) exprNode() {} func (*BadExpr) exprNode() {}
func (x *Ident) exprNode() {} func (*Ident) exprNode() {}
func (x *Ellipsis) exprNode() {} func (*Ellipsis) exprNode() {}
func (x *BasicLit) exprNode() {} func (*BasicLit) exprNode() {}
func (x *FuncLit) exprNode() {} func (*FuncLit) exprNode() {}
func (x *CompositeLit) exprNode() {} func (*CompositeLit) exprNode() {}
func (x *ParenExpr) exprNode() {} func (*ParenExpr) exprNode() {}
func (x *SelectorExpr) exprNode() {} func (*SelectorExpr) exprNode() {}
func (x *IndexExpr) exprNode() {} func (*IndexExpr) exprNode() {}
func (x *SliceExpr) exprNode() {} func (*SliceExpr) exprNode() {}
func (x *TypeAssertExpr) exprNode() {} func (*TypeAssertExpr) exprNode() {}
func (x *CallExpr) exprNode() {} func (*CallExpr) exprNode() {}
func (x *StarExpr) exprNode() {} func (*StarExpr) exprNode() {}
func (x *UnaryExpr) exprNode() {} func (*UnaryExpr) exprNode() {}
func (x *BinaryExpr) exprNode() {} func (*BinaryExpr) exprNode() {}
func (x *KeyValueExpr) exprNode() {} func (*KeyValueExpr) exprNode() {}
func (x *ArrayType) exprNode() {} func (*ArrayType) exprNode() {}
func (x *StructType) exprNode() {} func (*StructType) exprNode() {}
func (x *FuncType) exprNode() {} func (*FuncType) exprNode() {}
func (x *InterfaceType) exprNode() {} func (*InterfaceType) exprNode() {}
func (x *MapType) exprNode() {} func (*MapType) exprNode() {}
func (x *ChanType) exprNode() {} func (*ChanType) exprNode() {}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Convenience functions for Idents // Convenience functions for Idents
...@@ -711,27 +711,27 @@ func (s *RangeStmt) End() token.Pos { return s.Body.End() } ...@@ -711,27 +711,27 @@ func (s *RangeStmt) End() token.Pos { return s.Body.End() }
// stmtNode() ensures that only statement nodes can be // stmtNode() ensures that only statement nodes can be
// assigned to a StmtNode. // assigned to a StmtNode.
// //
func (s *BadStmt) stmtNode() {} func (*BadStmt) stmtNode() {}
func (s *DeclStmt) stmtNode() {} func (*DeclStmt) stmtNode() {}
func (s *EmptyStmt) stmtNode() {} func (*EmptyStmt) stmtNode() {}
func (s *LabeledStmt) stmtNode() {} func (*LabeledStmt) stmtNode() {}
func (s *ExprStmt) stmtNode() {} func (*ExprStmt) stmtNode() {}
func (s *SendStmt) stmtNode() {} func (*SendStmt) stmtNode() {}
func (s *IncDecStmt) stmtNode() {} func (*IncDecStmt) stmtNode() {}
func (s *AssignStmt) stmtNode() {} func (*AssignStmt) stmtNode() {}
func (s *GoStmt) stmtNode() {} func (*GoStmt) stmtNode() {}
func (s *DeferStmt) stmtNode() {} func (*DeferStmt) stmtNode() {}
func (s *ReturnStmt) stmtNode() {} func (*ReturnStmt) stmtNode() {}
func (s *BranchStmt) stmtNode() {} func (*BranchStmt) stmtNode() {}
func (s *BlockStmt) stmtNode() {} func (*BlockStmt) stmtNode() {}
func (s *IfStmt) stmtNode() {} func (*IfStmt) stmtNode() {}
func (s *CaseClause) stmtNode() {} func (*CaseClause) stmtNode() {}
func (s *SwitchStmt) stmtNode() {} func (*SwitchStmt) stmtNode() {}
func (s *TypeSwitchStmt) stmtNode() {} func (*TypeSwitchStmt) stmtNode() {}
func (s *CommClause) stmtNode() {} func (*CommClause) stmtNode() {}
func (s *SelectStmt) stmtNode() {} func (*SelectStmt) stmtNode() {}
func (s *ForStmt) stmtNode() {} func (*ForStmt) stmtNode() {}
func (s *RangeStmt) stmtNode() {} func (*RangeStmt) stmtNode() {}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Declarations // Declarations
...@@ -807,9 +807,9 @@ func (s *TypeSpec) End() token.Pos { return s.Type.End() } ...@@ -807,9 +807,9 @@ func (s *TypeSpec) End() token.Pos { return s.Type.End() }
// specNode() ensures that only spec nodes can be // specNode() ensures that only spec nodes can be
// assigned to a Spec. // assigned to a Spec.
// //
func (s *ImportSpec) specNode() {} func (*ImportSpec) specNode() {}
func (s *ValueSpec) specNode() {} func (*ValueSpec) specNode() {}
func (s *TypeSpec) specNode() {} func (*TypeSpec) specNode() {}
// A declaration is represented by one of the following declaration nodes. // A declaration is represented by one of the following declaration nodes.
// //
...@@ -875,9 +875,9 @@ func (d *FuncDecl) End() token.Pos { ...@@ -875,9 +875,9 @@ func (d *FuncDecl) End() token.Pos {
// declNode() ensures that only declaration nodes can be // declNode() ensures that only declaration nodes can be
// assigned to a DeclNode. // assigned to a DeclNode.
// //
func (d *BadDecl) declNode() {} func (*BadDecl) declNode() {}
func (d *GenDecl) declNode() {} func (*GenDecl) declNode() {}
func (d *FuncDecl) declNode() {} func (*FuncDecl) declNode() {}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Files and packages // Files and packages
......
...@@ -24,7 +24,7 @@ func exportFilter(name string) bool { ...@@ -24,7 +24,7 @@ func exportFilter(name string) bool {
// it returns false otherwise. // it returns false otherwise.
// //
func FileExports(src *File) bool { func FileExports(src *File) bool {
return FilterFile(src, exportFilter) return filterFile(src, exportFilter, true)
} }
// PackageExports trims the AST for a Go package in place such that // PackageExports trims the AST for a Go package in place such that
...@@ -35,7 +35,7 @@ func FileExports(src *File) bool { ...@@ -35,7 +35,7 @@ func FileExports(src *File) bool {
// it returns false otherwise. // it returns false otherwise.
// //
func PackageExports(pkg *Package) bool { func PackageExports(pkg *Package) bool {
return FilterPackage(pkg, exportFilter) return filterPackage(pkg, exportFilter, true)
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
...@@ -72,7 +72,7 @@ func fieldName(x Expr) *Ident { ...@@ -72,7 +72,7 @@ func fieldName(x Expr) *Ident {
return nil return nil
} }
func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) { func filterFieldList(fields *FieldList, filter Filter, export bool) (removedFields bool) {
if fields == nil { if fields == nil {
return false return false
} }
...@@ -93,8 +93,8 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) { ...@@ -93,8 +93,8 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
keepField = len(f.Names) > 0 keepField = len(f.Names) > 0
} }
if keepField { if keepField {
if filter == exportFilter { if export {
filterType(f.Type, filter) filterType(f.Type, filter, export)
} }
list[j] = f list[j] = f
j++ j++
...@@ -107,84 +107,84 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) { ...@@ -107,84 +107,84 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) {
return return
} }
func filterParamList(fields *FieldList, filter Filter) bool { func filterParamList(fields *FieldList, filter Filter, export bool) bool {
if fields == nil { if fields == nil {
return false return false
} }
var b bool var b bool
for _, f := range fields.List { for _, f := range fields.List {
if filterType(f.Type, filter) { if filterType(f.Type, filter, export) {
b = true b = true
} }
} }
return b return b
} }
func filterType(typ Expr, f Filter) bool { func filterType(typ Expr, f Filter, export bool) bool {
switch t := typ.(type) { switch t := typ.(type) {
case *Ident: case *Ident:
return f(t.Name) return f(t.Name)
case *ParenExpr: case *ParenExpr:
return filterType(t.X, f) return filterType(t.X, f, export)
case *ArrayType: case *ArrayType:
return filterType(t.Elt, f) return filterType(t.Elt, f, export)
case *StructType: case *StructType:
if filterFieldList(t.Fields, f) { if filterFieldList(t.Fields, f, export) {
t.Incomplete = true t.Incomplete = true
} }
return len(t.Fields.List) > 0 return len(t.Fields.List) > 0
case *FuncType: case *FuncType:
b1 := filterParamList(t.Params, f) b1 := filterParamList(t.Params, f, export)
b2 := filterParamList(t.Results, f) b2 := filterParamList(t.Results, f, export)
return b1 || b2 return b1 || b2
case *InterfaceType: case *InterfaceType:
if filterFieldList(t.Methods, f) { if filterFieldList(t.Methods, f, export) {
t.Incomplete = true t.Incomplete = true
} }
return len(t.Methods.List) > 0 return len(t.Methods.List) > 0
case *MapType: case *MapType:
b1 := filterType(t.Key, f) b1 := filterType(t.Key, f, export)
b2 := filterType(t.Value, f) b2 := filterType(t.Value, f, export)
return b1 || b2 return b1 || b2
case *ChanType: case *ChanType:
return filterType(t.Value, f) return filterType(t.Value, f, export)
} }
return false return false
} }
func filterSpec(spec Spec, f Filter) bool { func filterSpec(spec Spec, f Filter, export bool) bool {
switch s := spec.(type) { switch s := spec.(type) {
case *ValueSpec: case *ValueSpec:
s.Names = filterIdentList(s.Names, f) s.Names = filterIdentList(s.Names, f)
if len(s.Names) > 0 { if len(s.Names) > 0 {
if f == exportFilter { if export {
filterType(s.Type, f) filterType(s.Type, f, export)
} }
return true return true
} }
case *TypeSpec: case *TypeSpec:
if f(s.Name.Name) { if f(s.Name.Name) {
if f == exportFilter { if export {
filterType(s.Type, f) filterType(s.Type, f, export)
} }
return true return true
} }
if f != exportFilter { if !export {
// For general filtering (not just exports), // For general filtering (not just exports),
// filter type even if name is not filtered // filter type even if name is not filtered
// out. // out.
// If the type contains filtered elements, // If the type contains filtered elements,
// keep the declaration. // keep the declaration.
return filterType(s.Type, f) return filterType(s.Type, f, export)
} }
} }
return false return false
} }
func filterSpecList(list []Spec, f Filter) []Spec { func filterSpecList(list []Spec, f Filter, export bool) []Spec {
j := 0 j := 0
for _, s := range list { for _, s := range list {
if filterSpec(s, f) { if filterSpec(s, f, export) {
list[j] = s list[j] = s
j++ j++
} }
...@@ -200,9 +200,13 @@ func filterSpecList(list []Spec, f Filter) []Spec { ...@@ -200,9 +200,13 @@ func filterSpecList(list []Spec, f Filter) []Spec {
// filtering; it returns false otherwise. // filtering; it returns false otherwise.
// //
func FilterDecl(decl Decl, f Filter) bool { func FilterDecl(decl Decl, f Filter) bool {
return filterDecl(decl, f, false)
}
func filterDecl(decl Decl, f Filter, export bool) bool {
switch d := decl.(type) { switch d := decl.(type) {
case *GenDecl: case *GenDecl:
d.Specs = filterSpecList(d.Specs, f) d.Specs = filterSpecList(d.Specs, f, export)
return len(d.Specs) > 0 return len(d.Specs) > 0
case *FuncDecl: case *FuncDecl:
return f(d.Name.Name) return f(d.Name.Name)
...@@ -221,9 +225,13 @@ func FilterDecl(decl Decl, f Filter) bool { ...@@ -221,9 +225,13 @@ func FilterDecl(decl Decl, f Filter) bool {
// left after filtering; it returns false otherwise. // left after filtering; it returns false otherwise.
// //
func FilterFile(src *File, f Filter) bool { func FilterFile(src *File, f Filter) bool {
return filterFile(src, f, false)
}
func filterFile(src *File, f Filter, export bool) bool {
j := 0 j := 0
for _, d := range src.Decls { for _, d := range src.Decls {
if FilterDecl(d, f) { if filterDecl(d, f, export) {
src.Decls[j] = d src.Decls[j] = d
j++ j++
} }
...@@ -244,9 +252,13 @@ func FilterFile(src *File, f Filter) bool { ...@@ -244,9 +252,13 @@ func FilterFile(src *File, f Filter) bool {
// left after filtering; it returns false otherwise. // left after filtering; it returns false otherwise.
// //
func FilterPackage(pkg *Package, f Filter) bool { func FilterPackage(pkg *Package, f Filter) bool {
return filterPackage(pkg, f, false)
}
func filterPackage(pkg *Package, f Filter, export bool) bool {
hasDecls := false hasDecls := false
for _, src := range pkg.Files { for _, src := range pkg.Files {
if FilterFile(src, f) { if filterFile(src, f, export) {
hasDecls = true hasDecls = true
} }
} }
......
...@@ -37,18 +37,20 @@ var buildPkgs = []struct { ...@@ -37,18 +37,20 @@ var buildPkgs = []struct {
{ {
"go/build/cmdtest", "go/build/cmdtest",
&DirInfo{ &DirInfo{
GoFiles: []string{"main.go"}, GoFiles: []string{"main.go"},
Package: "main", Package: "main",
Imports: []string{"go/build/pkgtest"}, Imports: []string{"go/build/pkgtest"},
TestImports: []string{},
}, },
}, },
{ {
"go/build/cgotest", "go/build/cgotest",
&DirInfo{ &DirInfo{
CgoFiles: []string{"cgotest.go"}, CgoFiles: []string{"cgotest.go"},
CFiles: []string{"cgotest.c"}, CFiles: []string{"cgotest.c"},
Imports: []string{"C", "unsafe"}, Imports: []string{"C", "unsafe"},
Package: "cgotest", TestImports: []string{},
Package: "cgotest",
}, },
}, },
} }
......
...@@ -13,6 +13,8 @@ import ( ...@@ -13,6 +13,8 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"strings"
"text/tabwriter" "text/tabwriter"
) )
...@@ -244,6 +246,8 @@ func (p *printer) writeItem(pos token.Position, data string) { ...@@ -244,6 +246,8 @@ func (p *printer) writeItem(pos token.Position, data string) {
p.last = p.pos p.last = p.pos
} }
const linePrefix = "//line "
// writeCommentPrefix writes the whitespace before a comment. // writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of // If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely. // it as is likely to help position the comment nicely.
...@@ -252,7 +256,7 @@ func (p *printer) writeItem(pos token.Position, data string) { ...@@ -252,7 +256,7 @@ func (p *printer) writeItem(pos token.Position, data string) {
// a group of comments (or nil), and isKeyword indicates if the // a group of comments (or nil), and isKeyword indicates if the
// next item is a keyword. // next item is a keyword.
// //
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, isKeyword bool) { func (p *printer) writeCommentPrefix(pos, next token.Position, prev, comment *ast.Comment, isKeyword bool) {
if p.written == 0 { if p.written == 0 {
// the comment is the first item to be printed - don't write any whitespace // the comment is the first item to be printed - don't write any whitespace
return return
...@@ -337,6 +341,13 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment ...@@ -337,6 +341,13 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
} }
p.writeWhitespace(j) p.writeWhitespace(j)
} }
// turn off indent if we're about to print a line directive.
indent := p.indent
if strings.HasPrefix(comment.Text, linePrefix) {
p.indent = 0
}
// use formfeeds to break columns before a comment; // use formfeeds to break columns before a comment;
// this is analogous to using formfeeds to separate // this is analogous to using formfeeds to separate
// individual lines of /*-style comments - but make // individual lines of /*-style comments - but make
...@@ -347,6 +358,7 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment ...@@ -347,6 +358,7 @@ func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment
n = 1 n = 1
} }
p.writeNewlines(n, true) p.writeNewlines(n, true)
p.indent = indent
} }
} }
...@@ -526,6 +538,26 @@ func stripCommonPrefix(lines [][]byte) { ...@@ -526,6 +538,26 @@ func stripCommonPrefix(lines [][]byte) {
func (p *printer) writeComment(comment *ast.Comment) { func (p *printer) writeComment(comment *ast.Comment) {
text := comment.Text text := comment.Text
if strings.HasPrefix(text, linePrefix) {
pos := strings.TrimSpace(text[len(linePrefix):])
i := strings.LastIndex(pos, ":")
if i >= 0 {
// The line directive we are about to print changed
// the Filename and Line number used by go/token
// as it was reading the input originally.
// In order to match the original input, we have to
// update our own idea of the file and line number
// accordingly, after printing the directive.
file := pos[:i]
line, _ := strconv.Atoi(string(pos[i+1:]))
defer func() {
p.pos.Filename = string(file)
p.pos.Line = line
p.pos.Column = 1
}()
}
}
// shortcut common case of //-style comments // shortcut common case of //-style comments
if text[1] == '/' { if text[1] == '/' {
p.writeItem(p.fset.Position(comment.Pos()), p.escape(text)) p.writeItem(p.fset.Position(comment.Pos()), p.escape(text))
...@@ -599,7 +631,7 @@ func (p *printer) intersperseComments(next token.Position, tok token.Token) (dro ...@@ -599,7 +631,7 @@ func (p *printer) intersperseComments(next token.Position, tok token.Token) (dro
var last *ast.Comment var last *ast.Comment
for ; p.commentBefore(next); p.cindex++ { for ; p.commentBefore(next); p.cindex++ {
for _, c := range p.comments[p.cindex].List { for _, c := range p.comments[p.cindex].List {
p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, tok.IsKeyword()) p.writeCommentPrefix(p.fset.Position(c.Pos()), next, last, c, tok.IsKeyword())
p.writeComment(c) p.writeComment(c)
last = c last = c
} }
......
...@@ -37,7 +37,7 @@ lower-cased, and attributes are collected into a []Attribute. For example: ...@@ -37,7 +37,7 @@ lower-cased, and attributes are collected into a []Attribute. For example:
for { for {
if z.Next() == html.ErrorToken { if z.Next() == html.ErrorToken {
// Returning io.EOF indicates success. // Returning io.EOF indicates success.
return z.Error() return z.Err()
} }
emitToken(z.Token()) emitToken(z.Token())
} }
...@@ -51,7 +51,7 @@ call to Next. For example, to extract an HTML page's anchor text: ...@@ -51,7 +51,7 @@ call to Next. For example, to extract an HTML page's anchor text:
tt := z.Next() tt := z.Next()
switch tt { switch tt {
case ErrorToken: case ErrorToken:
return z.Error() return z.Err()
case TextToken: case TextToken:
if depth > 0 { if depth > 0 {
// emitBytes should copy the []byte it receives, // emitBytes should copy the []byte it receives,
......
...@@ -133,8 +133,8 @@ func TestParser(t *testing.T) { ...@@ -133,8 +133,8 @@ func TestParser(t *testing.T) {
n int n int
}{ }{
// TODO(nigeltao): Process all the test cases from all the .dat files. // TODO(nigeltao): Process all the test cases from all the .dat files.
{"tests1.dat", 92}, {"tests1.dat", -1},
{"tests2.dat", 0}, {"tests2.dat", 43},
{"tests3.dat", 0}, {"tests3.dat", 0},
} }
for _, tf := range testFiles { for _, tf := range testFiles {
...@@ -213,4 +213,8 @@ var renderTestBlacklist = map[string]bool{ ...@@ -213,4 +213,8 @@ var renderTestBlacklist = map[string]bool{
// More cases of <a> being reparented: // More cases of <a> being reparented:
`<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true, `<a href="blah">aba<table><a href="foo">br<tr><td></td></tr>x</table>aoe`: true,
`<a><table><a></table><p><a><div><a>`: true, `<a><table><a></table><p><a><div><a>`: true,
`<a><table><td><a><table></table><a></tr><a></table><a>`: true,
// A <plaintext> element is reparented, putting it before a table.
// A <plaintext> element can't have anything after it in HTML.
`<table><plaintext><td>`: true,
} }
...@@ -52,7 +52,19 @@ func Render(w io.Writer, n *Node) error { ...@@ -52,7 +52,19 @@ func Render(w io.Writer, n *Node) error {
return buf.Flush() return buf.Flush()
} }
// plaintextAbort is returned from render1 when a <plaintext> element
// has been rendered. No more end tags should be rendered after that.
var plaintextAbort = errors.New("html: internal error (plaintext abort)")
func render(w writer, n *Node) error { func render(w writer, n *Node) error {
err := render1(w, n)
if err == plaintextAbort {
err = nil
}
return err
}
func render1(w writer, n *Node) error {
// Render non-element nodes; these are the easy cases. // Render non-element nodes; these are the easy cases.
switch n.Type { switch n.Type {
case ErrorNode: case ErrorNode:
...@@ -61,7 +73,7 @@ func render(w writer, n *Node) error { ...@@ -61,7 +73,7 @@ func render(w writer, n *Node) error {
return escape(w, n.Data) return escape(w, n.Data)
case DocumentNode: case DocumentNode:
for _, c := range n.Child { for _, c := range n.Child {
if err := render(w, c); err != nil { if err := render1(w, c); err != nil {
return err return err
} }
} }
...@@ -128,7 +140,7 @@ func render(w writer, n *Node) error { ...@@ -128,7 +140,7 @@ func render(w writer, n *Node) error {
// Render any child nodes. // Render any child nodes.
switch n.Data { switch n.Data {
case "noembed", "noframes", "noscript", "script", "style": case "noembed", "noframes", "noscript", "plaintext", "script", "style":
for _, c := range n.Child { for _, c := range n.Child {
if c.Type != TextNode { if c.Type != TextNode {
return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data) return fmt.Errorf("html: raw text element <%s> has non-text child node", n.Data)
...@@ -137,18 +149,23 @@ func render(w writer, n *Node) error { ...@@ -137,18 +149,23 @@ func render(w writer, n *Node) error {
return err return err
} }
} }
if n.Data == "plaintext" {
// Don't render anything else. <plaintext> must be the
// last element in the file, with no closing tag.
return plaintextAbort
}
case "textarea", "title": case "textarea", "title":
for _, c := range n.Child { for _, c := range n.Child {
if c.Type != TextNode { if c.Type != TextNode {
return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data) return fmt.Errorf("html: RCDATA element <%s> has non-text child node", n.Data)
} }
if err := render(w, c); err != nil { if err := render1(w, c); err != nil {
return err return err
} }
} }
default: default:
for _, c := range n.Child { for _, c := range n.Child {
if err := render(w, c); err != nil { if err := render1(w, c); err != nil {
return err return err
} }
} }
......
...@@ -6,6 +6,7 @@ package template ...@@ -6,6 +6,7 @@ package template
import ( import (
"fmt" "fmt"
"reflect"
) )
// Strings of content from a trusted source. // Strings of content from a trusted source.
...@@ -70,10 +71,25 @@ const ( ...@@ -70,10 +71,25 @@ const (
contentTypeUnsafe contentTypeUnsafe
) )
// indirect returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil).
func indirect(a interface{}) interface{} {
if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr {
// Avoid creating a reflect.Value if it's not a pointer.
return a
}
v := reflect.ValueOf(a)
for v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// stringify converts its arguments to a string and the type of the content. // stringify converts its arguments to a string and the type of the content.
// All pointers are dereferenced, as in the text/template package.
func stringify(args ...interface{}) (string, contentType) { func stringify(args ...interface{}) (string, contentType) {
if len(args) == 1 { if len(args) == 1 {
switch s := args[0].(type) { switch s := indirect(args[0]).(type) {
case string: case string:
return s, contentTypePlain return s, contentTypePlain
case CSS: case CSS:
...@@ -90,5 +106,8 @@ func stringify(args ...interface{}) (string, contentType) { ...@@ -90,5 +106,8 @@ func stringify(args ...interface{}) (string, contentType) {
return string(s), contentTypeURL return string(s), contentTypeURL
} }
} }
for i, arg := range args {
args[i] = indirect(arg)
}
return fmt.Sprint(args...), contentTypePlain return fmt.Sprint(args...), contentTypePlain
} }
...@@ -28,7 +28,7 @@ func (x *goodMarshaler) MarshalJSON() ([]byte, error) { ...@@ -28,7 +28,7 @@ func (x *goodMarshaler) MarshalJSON() ([]byte, error) {
} }
func TestEscape(t *testing.T) { func TestEscape(t *testing.T) {
var data = struct { data := struct {
F, T bool F, T bool
C, G, H string C, G, H string
A, E []string A, E []string
...@@ -50,6 +50,7 @@ func TestEscape(t *testing.T) { ...@@ -50,6 +50,7 @@ func TestEscape(t *testing.T) {
Z: nil, Z: nil,
W: HTML(`&iexcl;<b class="foo">Hello</b>, <textarea>O'World</textarea>!`), W: HTML(`&iexcl;<b class="foo">Hello</b>, <textarea>O'World</textarea>!`),
} }
pdata := &data
tests := []struct { tests := []struct {
name string name string
...@@ -668,6 +669,15 @@ func TestEscape(t *testing.T) { ...@@ -668,6 +669,15 @@ func TestEscape(t *testing.T) {
t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g) t.Errorf("%s: escaped output: want\n\t%q\ngot\n\t%q", test.name, w, g)
continue continue
} }
b.Reset()
if err := tmpl.Execute(b, pdata); err != nil {
t.Errorf("%s: template execution failed for pointer: %s", test.name, err)
continue
}
if w, g := test.output, b.String(); w != g {
t.Errorf("%s: escaped output for pointer: want\n\t%q\ngot\n\t%q", test.name, w, g)
continue
}
} }
} }
...@@ -1605,6 +1615,29 @@ func TestRedundantFuncs(t *testing.T) { ...@@ -1605,6 +1615,29 @@ func TestRedundantFuncs(t *testing.T) {
} }
} }
func TestIndirectPrint(t *testing.T) {
a := 3
ap := &a
b := "hello"
bp := &b
bpp := &bp
tmpl := Must(New("t").Parse(`{{.}}`))
var buf bytes.Buffer
err := tmpl.Execute(&buf, ap)
if err != nil {
t.Errorf("Unexpected error: %s", err)
} else if buf.String() != "3" {
t.Errorf(`Expected "3"; got %q`, buf.String())
}
buf.Reset()
err = tmpl.Execute(&buf, bpp)
if err != nil {
t.Errorf("Unexpected error: %s", err)
} else if buf.String() != "hello" {
t.Errorf(`Expected "hello"; got %q`, buf.String())
}
}
func BenchmarkEscapedExecute(b *testing.B) { func BenchmarkEscapedExecute(b *testing.B) {
tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`)) tmpl := Must(New("t").Parse(`<a onclick="alert('{{.}}')">{{.}}</a>`))
var buf bytes.Buffer var buf bytes.Buffer
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
) )
...@@ -117,12 +118,24 @@ var regexpPrecederKeywords = map[string]bool{ ...@@ -117,12 +118,24 @@ var regexpPrecederKeywords = map[string]bool{
"void": true, "void": true,
} }
var jsonMarshalType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
// indirectToJSONMarshaler returns the value, after dereferencing as many times
// as necessary to reach the base type (or nil) or an implementation of json.Marshal.
func indirectToJSONMarshaler(a interface{}) interface{} {
v := reflect.ValueOf(a)
for !v.Type().Implements(jsonMarshalType) && v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
return v.Interface()
}
// jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has // jsValEscaper escapes its inputs to a JS Expression (section 11.14) that has
// nether side-effects nor free variables outside (NaN, Infinity). // neither side-effects nor free variables outside (NaN, Infinity).
func jsValEscaper(args ...interface{}) string { func jsValEscaper(args ...interface{}) string {
var a interface{} var a interface{}
if len(args) == 1 { if len(args) == 1 {
a = args[0] a = indirectToJSONMarshaler(args[0])
switch t := a.(type) { switch t := a.(type) {
case JS: case JS:
return string(t) return string(t)
...@@ -135,6 +148,9 @@ func jsValEscaper(args ...interface{}) string { ...@@ -135,6 +148,9 @@ func jsValEscaper(args ...interface{}) string {
a = t.String() a = t.String()
} }
} else { } else {
for i, arg := range args {
args[i] = indirectToJSONMarshaler(arg)
}
a = fmt.Sprint(args...) a = fmt.Sprint(args...)
} }
// TODO: detect cycles before calling Marshal which loops infinitely on // TODO: detect cycles before calling Marshal which loops infinitely on
......
...@@ -401,14 +401,14 @@ func (z *Tokenizer) readStartTag() TokenType { ...@@ -401,14 +401,14 @@ func (z *Tokenizer) readStartTag() TokenType {
break break
} }
} }
// Any "<noembed>", "<noframes>", "<noscript>", "<script>", "<style>", // Any "<noembed>", "<noframes>", "<noscript>", "<plaintext", "<script>", "<style>",
// "<textarea>" or "<title>" tag flags the tokenizer's next token as raw. // "<textarea>" or "<title>" tag flags the tokenizer's next token as raw.
// The tag name lengths of these special cases ranges in [5, 8]. // The tag name lengths of these special cases ranges in [5, 9].
if x := z.data.end - z.data.start; 5 <= x && x <= 8 { if x := z.data.end - z.data.start; 5 <= x && x <= 9 {
switch z.buf[z.data.start] { switch z.buf[z.data.start] {
case 'n', 's', 't', 'N', 'S', 'T': case 'n', 'p', 's', 't', 'N', 'P', 'S', 'T':
switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s { switch s := strings.ToLower(string(z.buf[z.data.start:z.data.end])); s {
case "noembed", "noframes", "noscript", "script", "style", "textarea", "title": case "noembed", "noframes", "noscript", "plaintext", "script", "style", "textarea", "title":
z.rawTag = s z.rawTag = s
} }
} }
...@@ -551,9 +551,19 @@ func (z *Tokenizer) Next() TokenType { ...@@ -551,9 +551,19 @@ func (z *Tokenizer) Next() TokenType {
z.data.start = z.raw.end z.data.start = z.raw.end
z.data.end = z.raw.end z.data.end = z.raw.end
if z.rawTag != "" { if z.rawTag != "" {
z.readRawOrRCDATA() if z.rawTag == "plaintext" {
z.tt = TextToken // Read everything up to EOF.
return z.tt for z.err == nil {
z.readByte()
}
z.textIsRaw = true
} else {
z.readRawOrRCDATA()
}
if z.data.end > z.data.start {
z.tt = TextToken
return z.tt
}
} }
z.textIsRaw = false z.textIsRaw = false
......
...@@ -4,10 +4,7 @@ ...@@ -4,10 +4,7 @@
package tiff package tiff
import ( import "io"
"io"
"os"
)
// buffer buffers an io.Reader to satisfy io.ReaderAt. // buffer buffers an io.Reader to satisfy io.ReaderAt.
type buffer struct { type buffer struct {
...@@ -19,7 +16,7 @@ func (b *buffer) ReadAt(p []byte, off int64) (int, error) { ...@@ -19,7 +16,7 @@ func (b *buffer) ReadAt(p []byte, off int64) (int, error) {
o := int(off) o := int(off)
end := o + len(p) end := o + len(p)
if int64(end) != off+int64(len(p)) { if int64(end) != off+int64(len(p)) {
return 0, os.EINVAL return 0, io.ErrUnexpectedEOF
} }
m := len(b.buf) m := len(b.buf)
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"time"
) )
// Random number state, accessed without lock; racy but harmless. // Random number state, accessed without lock; racy but harmless.
...@@ -17,8 +18,7 @@ import ( ...@@ -17,8 +18,7 @@ import (
var rand uint32 var rand uint32
func reseed() uint32 { func reseed() uint32 {
sec, nsec, _ := os.Time() return uint32(time.Nanoseconds() + int64(os.Getpid()))
return uint32(sec*1e9 + nsec + int64(os.Getpid()))
} }
func nextSuffix() string { func nextSuffix() string {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
package syslog package syslog
import ( import (
"errors"
"fmt" "fmt"
"log" "log"
"net" "net"
...@@ -75,7 +76,7 @@ func Dial(network, raddr string, priority Priority, prefix string) (w *Writer, e ...@@ -75,7 +76,7 @@ func Dial(network, raddr string, priority Priority, prefix string) (w *Writer, e
// Write sends a log message to the syslog daemon. // Write sends a log message to the syslog daemon.
func (w *Writer) Write(b []byte) (int, error) { func (w *Writer) Write(b []byte) (int, error) {
if w.priority > LOG_DEBUG || w.priority < LOG_EMERG { if w.priority > LOG_DEBUG || w.priority < LOG_EMERG {
return 0, os.EINVAL return 0, errors.New("log/syslog: invalid priority")
} }
return w.conn.writeBytes(w.priority, w.prefix, b) return w.conn.writeBytes(w.priority, w.prefix, b)
} }
......
...@@ -176,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int { ...@@ -176,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int {
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
// Rem implements truncated modulus (like Go); see QuoRem for more details. // Rem implements truncated modulus (like Go); see QuoRem for more details.
func (z *Int) Rem(x, y *Int) *Int { func (z *Int) Rem(x, y *Int) *Int {
_, z.abs = nat{}.div(z.abs, x.abs, y.abs) _, z.abs = nat(nil).div(z.abs, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
return z return z
} }
...@@ -678,14 +678,14 @@ func (z *Int) Bit(i int) uint { ...@@ -678,14 +678,14 @@ func (z *Int) Bit(i int) uint {
panic("negative bit index") panic("negative bit index")
} }
if z.neg { if z.neg {
t := nat{}.sub(z.abs, natOne) t := nat(nil).sub(z.abs, natOne)
return t.bit(uint(i)) ^ 1 return t.bit(uint(i)) ^ 1
} }
return z.abs.bit(uint(i)) return z.abs.bit(uint(i))
} }
// SetBit sets the i'th bit of z to bit and returns z. // SetBit sets z to x, with x's i'th bit set to b (0 or 1).
// That is, if bit is 1 SetBit sets z = x | (1 << i); // That is, if bit is 1 SetBit sets z = x | (1 << i);
// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1, // if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
// SetBit will panic. // SetBit will panic.
...@@ -710,8 +710,8 @@ func (z *Int) And(x, y *Int) *Int { ...@@ -710,8 +710,8 @@ func (z *Int) And(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1) // (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y1), natOne) z.abs = z.abs.add(z.abs.or(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative z.neg = true // z cannot be zero if x and y are negative
return z return z
...@@ -729,7 +729,7 @@ func (z *Int) And(x, y *Int) *Int { ...@@ -729,7 +729,7 @@ func (z *Int) And(x, y *Int) *Int {
} }
// x & (-y) == x & ^(y-1) == x &^ (y-1) // x & (-y) == x & ^(y-1) == x &^ (y-1)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(x.abs, y1) z.abs = z.abs.andNot(x.abs, y1)
z.neg = false z.neg = false
return z return z
...@@ -740,8 +740,8 @@ func (z *Int) AndNot(x, y *Int) *Int { ...@@ -740,8 +740,8 @@ func (z *Int) AndNot(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1) // (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.andNot(y1, x1) z.abs = z.abs.andNot(y1, x1)
z.neg = false z.neg = false
return z return z
...@@ -755,14 +755,14 @@ func (z *Int) AndNot(x, y *Int) *Int { ...@@ -755,14 +755,14 @@ func (z *Int) AndNot(x, y *Int) *Int {
if x.neg { if x.neg {
// (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1) // (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne) z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne)
z.neg = true // z cannot be zero if x is negative and y is positive z.neg = true // z cannot be zero if x is negative and y is positive
return z return z
} }
// x &^ (-y) == x &^ ^(y-1) == x & (y-1) // x &^ (-y) == x &^ ^(y-1) == x & (y-1)
y1 := nat{}.add(y.abs, natOne) y1 := nat(nil).add(y.abs, natOne)
z.abs = z.abs.and(x.abs, y1) z.abs = z.abs.and(x.abs, y1)
z.neg = false z.neg = false
return z return z
...@@ -773,8 +773,8 @@ func (z *Int) Or(x, y *Int) *Int { ...@@ -773,8 +773,8 @@ func (z *Int) Or(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1) // (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.and(x1, y1), natOne) z.abs = z.abs.add(z.abs.and(x1, y1), natOne)
z.neg = true // z cannot be zero if x and y are negative z.neg = true // z cannot be zero if x and y are negative
return z return z
...@@ -792,7 +792,7 @@ func (z *Int) Or(x, y *Int) *Int { ...@@ -792,7 +792,7 @@ func (z *Int) Or(x, y *Int) *Int {
} }
// x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1) // x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne) z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne)
z.neg = true // z cannot be zero if one of x or y is negative z.neg = true // z cannot be zero if one of x or y is negative
return z return z
...@@ -803,8 +803,8 @@ func (z *Int) Xor(x, y *Int) *Int { ...@@ -803,8 +803,8 @@ func (z *Int) Xor(x, y *Int) *Int {
if x.neg == y.neg { if x.neg == y.neg {
if x.neg { if x.neg {
// (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1) // (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1)
x1 := nat{}.sub(x.abs, natOne) x1 := nat(nil).sub(x.abs, natOne)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.xor(x1, y1) z.abs = z.abs.xor(x1, y1)
z.neg = false z.neg = false
return z return z
...@@ -822,7 +822,7 @@ func (z *Int) Xor(x, y *Int) *Int { ...@@ -822,7 +822,7 @@ func (z *Int) Xor(x, y *Int) *Int {
} }
// x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1) // x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1)
y1 := nat{}.sub(y.abs, natOne) y1 := nat(nil).sub(y.abs, natOne)
z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne) z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne)
z.neg = true // z cannot be zero if only one of x or y is negative z.neg = true // z cannot be zero if only one of x or y is negative
return z return z
......
...@@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat { ...@@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat {
case a == b: case a == b:
return z.setUint64(a) return z.setUint64(a)
case a+1 == b: case a+1 == b:
return z.mul(nat{}.setUint64(a), nat{}.setUint64(b)) return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
} }
m := (a + b) / 2 m := (a + b) / 2
return z.mul(nat{}.mulRange(a, m), nat{}.mulRange(m+1, b)) return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
} }
// q = (x-r)/y, with 0 <= r < y // q = (x-r)/y, with 0 <= r < y
...@@ -785,7 +785,7 @@ func (x nat) string(charset string) string { ...@@ -785,7 +785,7 @@ func (x nat) string(charset string) string {
} }
// preserve x, create local copy for use in repeated divisions // preserve x, create local copy for use in repeated divisions
q := nat{}.set(x) q := nat(nil).set(x)
var r Word var r Word
// convert // convert
...@@ -1191,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool { ...@@ -1191,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool {
return false return false
} }
nm1 := nat{}.sub(n, natOne) nm1 := nat(nil).sub(n, natOne)
// 1<<k * q = nm1; // 1<<k * q = nm1;
q, k := nm1.powersOfTwoDecompose() q, k := nm1.powersOfTwoDecompose()
nm3 := nat{}.sub(nm1, natTwo) nm3 := nat(nil).sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0]))) rand := rand.New(rand.NewSource(int64(n[0])))
var x, y, quotient nat var x, y, quotient nat
......
...@@ -16,9 +16,9 @@ var cmpTests = []struct { ...@@ -16,9 +16,9 @@ var cmpTests = []struct {
r int r int
}{ }{
{nil, nil, 0}, {nil, nil, 0},
{nil, nat{}, 0}, {nil, nat(nil), 0},
{nat{}, nil, 0}, {nat(nil), nil, 0},
{nat{}, nat{}, 0}, {nat(nil), nat(nil), 0},
{nat{0}, nat{0}, 0}, {nat{0}, nat{0}, 0},
{nat{0}, nat{1}, -1}, {nat{0}, nat{1}, -1},
{nat{1}, nat{0}, 1}, {nat{1}, nat{0}, 1},
...@@ -67,7 +67,7 @@ var prodNN = []argNN{ ...@@ -67,7 +67,7 @@ var prodNN = []argNN{
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
for _, a := range sumNN { for _, a := range sumNN {
z := nat{}.set(a.z) z := nat(nil).set(a.z)
if z.cmp(a.z) != 0 { if z.cmp(a.z) != 0 {
t.Errorf("got z = %v; want %v", z, a.z) t.Errorf("got z = %v; want %v", z, a.z)
} }
...@@ -129,7 +129,7 @@ var mulRangesN = []struct { ...@@ -129,7 +129,7 @@ var mulRangesN = []struct {
func TestMulRangeN(t *testing.T) { func TestMulRangeN(t *testing.T) {
for i, r := range mulRangesN { for i, r := range mulRangesN {
prod := nat{}.mulRange(r.a, r.b).decimalString() prod := nat(nil).mulRange(r.a, r.b).decimalString()
if prod != r.prod { if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod) t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
} }
...@@ -175,7 +175,7 @@ func toString(x nat, charset string) string { ...@@ -175,7 +175,7 @@ func toString(x nat, charset string) string {
s := make([]byte, i) s := make([]byte, i)
// don't destroy x // don't destroy x
q := nat{}.set(x) q := nat(nil).set(x)
// convert // convert
for len(q) > 0 { for len(q) > 0 {
...@@ -212,7 +212,7 @@ func TestString(t *testing.T) { ...@@ -212,7 +212,7 @@ func TestString(t *testing.T) {
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s) t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
} }
x, b, err := nat{}.scan(strings.NewReader(a.s), len(a.c)) x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c))
if x.cmp(a.x) != 0 { if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
} }
...@@ -271,7 +271,7 @@ var natScanTests = []struct { ...@@ -271,7 +271,7 @@ var natScanTests = []struct {
func TestScanBase(t *testing.T) { func TestScanBase(t *testing.T) {
for _, a := range natScanTests { for _, a := range natScanTests {
r := strings.NewReader(a.s) r := strings.NewReader(a.s)
x, b, err := nat{}.scan(r, a.base) x, b, err := nat(nil).scan(r, a.base)
if err == nil && !a.ok { if err == nil && !a.ok {
t.Errorf("scan%+v\n\texpected error", a) t.Errorf("scan%+v\n\texpected error", a)
} }
...@@ -651,17 +651,17 @@ var expNNTests = []struct { ...@@ -651,17 +651,17 @@ var expNNTests = []struct {
func TestExpNN(t *testing.T) { func TestExpNN(t *testing.T) {
for i, test := range expNNTests { for i, test := range expNNTests {
x, _, _ := nat{}.scan(strings.NewReader(test.x), 0) x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0)
y, _, _ := nat{}.scan(strings.NewReader(test.y), 0) y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0)
out, _, _ := nat{}.scan(strings.NewReader(test.out), 0) out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0)
var m nat var m nat
if len(test.m) > 0 { if len(test.m) > 0 {
m, _, _ = nat{}.scan(strings.NewReader(test.m), 0) m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0)
} }
z := nat{}.expNN(x, y, m) z := nat(nil).expNN(x, y, m)
if z.cmp(out) != 0 { if z.cmp(out) != 0 {
t.Errorf("#%d got %v want %v", i, z, out) t.Errorf("#%d got %v want %v", i, z, out)
} }
......
...@@ -33,7 +33,7 @@ func (z *Rat) SetFrac(a, b *Int) *Rat { ...@@ -33,7 +33,7 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
panic("division by zero") panic("division by zero")
} }
if &z.a == b || alias(z.a.abs, babs) { if &z.a == b || alias(z.a.abs, babs) {
babs = nat{}.set(babs) // make a copy babs = nat(nil).set(babs) // make a copy
} }
z.a.abs = z.a.abs.set(a.abs) z.a.abs = z.a.abs.set(a.abs)
z.b = z.b.set(babs) z.b = z.b.set(babs)
...@@ -315,7 +315,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) { ...@@ -315,7 +315,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
if _, ok := z.a.SetString(s, 10); !ok { if _, ok := z.a.SetString(s, 10); !ok {
return nil, false return nil, false
} }
powTen := nat{}.expNN(natTen, exp.abs, nil) powTen := nat(nil).expNN(natTen, exp.abs, nil)
if exp.neg { if exp.neg {
z.b = powTen z.b = powTen
z.norm() z.norm()
...@@ -357,23 +357,23 @@ func (z *Rat) FloatString(prec int) string { ...@@ -357,23 +357,23 @@ func (z *Rat) FloatString(prec int) string {
} }
// z.b != 0 // z.b != 0
q, r := nat{}.div(nat{}, z.a.abs, z.b) q, r := nat(nil).div(nat(nil), z.a.abs, z.b)
p := natOne p := natOne
if prec > 0 { if prec > 0 {
p = nat{}.expNN(natTen, nat{}.setUint64(uint64(prec)), nil) p = nat(nil).expNN(natTen, nat(nil).setUint64(uint64(prec)), nil)
} }
r = r.mul(r, p) r = r.mul(r, p)
r, r2 := r.div(nat{}, r, z.b) r, r2 := r.div(nat(nil), r, z.b)
// see if we need to round up // see if we need to round up
r2 = r2.add(r2, r2) r2 = r2.add(r2, r2)
if z.b.cmp(r2) <= 0 { if z.b.cmp(r2) <= 0 {
r = r.add(r, natOne) r = r.add(r, natOne)
if r.cmp(p) >= 0 { if r.cmp(p) >= 0 {
q = nat{}.add(q, natOne) q = nat(nil).add(q, natOne)
r = nat{}.sub(r, p) r = nat(nil).sub(r, p)
} }
} }
......
...@@ -63,7 +63,7 @@ package math ...@@ -63,7 +63,7 @@ package math
// Stephen L. Moshier // Stephen L. Moshier
// moshier@na-net.ornl.gov // moshier@na-net.ornl.gov
var _P = [...]float64{ var _gamP = [...]float64{
1.60119522476751861407e-04, 1.60119522476751861407e-04,
1.19135147006586384913e-03, 1.19135147006586384913e-03,
1.04213797561761569935e-02, 1.04213797561761569935e-02,
...@@ -72,7 +72,7 @@ var _P = [...]float64{ ...@@ -72,7 +72,7 @@ var _P = [...]float64{
4.94214826801497100753e-01, 4.94214826801497100753e-01,
9.99999999999999996796e-01, 9.99999999999999996796e-01,
} }
var _Q = [...]float64{ var _gamQ = [...]float64{
-2.31581873324120129819e-05, -2.31581873324120129819e-05,
5.39605580493303397842e-04, 5.39605580493303397842e-04,
-4.45641913851797240494e-03, -4.45641913851797240494e-03,
...@@ -82,7 +82,7 @@ var _Q = [...]float64{ ...@@ -82,7 +82,7 @@ var _Q = [...]float64{
7.14304917030273074085e-02, 7.14304917030273074085e-02,
1.00000000000000000320e+00, 1.00000000000000000320e+00,
} }
var _S = [...]float64{ var _gamS = [...]float64{
7.87311395793093628397e-04, 7.87311395793093628397e-04,
-2.29549961613378126380e-04, -2.29549961613378126380e-04,
-2.68132617805781232825e-03, -2.68132617805781232825e-03,
...@@ -98,7 +98,7 @@ func stirling(x float64) float64 { ...@@ -98,7 +98,7 @@ func stirling(x float64) float64 {
MaxStirling = 143.01608 MaxStirling = 143.01608
) )
w := 1 / x w := 1 / x
w = 1 + w*((((_S[0]*w+_S[1])*w+_S[2])*w+_S[3])*w+_S[4]) w = 1 + w*((((_gamS[0]*w+_gamS[1])*w+_gamS[2])*w+_gamS[3])*w+_gamS[4])
y := Exp(x) y := Exp(x)
if x > MaxStirling { // avoid Pow() overflow if x > MaxStirling { // avoid Pow() overflow
v := Pow(x, 0.5*x-0.25) v := Pow(x, 0.5*x-0.25)
...@@ -176,8 +176,8 @@ func Gamma(x float64) float64 { ...@@ -176,8 +176,8 @@ func Gamma(x float64) float64 {
} }
x = x - 2 x = x - 2
p = (((((x*_P[0]+_P[1])*x+_P[2])*x+_P[3])*x+_P[4])*x+_P[5])*x + _P[6] p = (((((x*_gamP[0]+_gamP[1])*x+_gamP[2])*x+_gamP[3])*x+_gamP[4])*x+_gamP[5])*x + _gamP[6]
q = ((((((x*_Q[0]+_Q[1])*x+_Q[2])*x+_Q[3])*x+_Q[4])*x+_Q[5])*x+_Q[6])*x + _Q[7] q = ((((((x*_gamQ[0]+_gamQ[1])*x+_gamQ[2])*x+_gamQ[3])*x+_gamQ[4])*x+_gamQ[5])*x+_gamQ[6])*x + _gamQ[7]
return z * p / q return z * p / q
small: small:
......
...@@ -88,6 +88,81 @@ package math ...@@ -88,6 +88,81 @@ package math
// //
// //
var _lgamA = [...]float64{
7.72156649015328655494e-02, // 0x3FB3C467E37DB0C8
3.22467033424113591611e-01, // 0x3FD4A34CC4A60FAD
6.73523010531292681824e-02, // 0x3FB13E001A5562A7
2.05808084325167332806e-02, // 0x3F951322AC92547B
7.38555086081402883957e-03, // 0x3F7E404FB68FEFE8
2.89051383673415629091e-03, // 0x3F67ADD8CCB7926B
1.19270763183362067845e-03, // 0x3F538A94116F3F5D
5.10069792153511336608e-04, // 0x3F40B6C689B99C00
2.20862790713908385557e-04, // 0x3F2CF2ECED10E54D
1.08011567247583939954e-04, // 0x3F1C5088987DFB07
2.52144565451257326939e-05, // 0x3EFA7074428CFA52
4.48640949618915160150e-05, // 0x3F07858E90A45837
}
var _lgamR = [...]float64{
1.0, // placeholder
1.39200533467621045958e+00, // 0x3FF645A762C4AB74
7.21935547567138069525e-01, // 0x3FE71A1893D3DCDC
1.71933865632803078993e-01, // 0x3FC601EDCCFBDF27
1.86459191715652901344e-02, // 0x3F9317EA742ED475
7.77942496381893596434e-04, // 0x3F497DDACA41A95B
7.32668430744625636189e-06, // 0x3EDEBAF7A5B38140
}
var _lgamS = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
2.14982415960608852501e-01, // 0x3FCB848B36E20878
3.25778796408930981787e-01, // 0x3FD4D98F4F139F59
1.46350472652464452805e-01, // 0x3FC2BB9CBEE5F2F7
2.66422703033638609560e-02, // 0x3F9B481C7E939961
1.84028451407337715652e-03, // 0x3F5E26B67368F239
3.19475326584100867617e-05, // 0x3F00BFECDD17E945
}
var _lgamT = [...]float64{
4.83836122723810047042e-01, // 0x3FDEF72BC8EE38A2
-1.47587722994593911752e-01, // 0xBFC2E4278DC6C509
6.46249402391333854778e-02, // 0x3FB08B4294D5419B
-3.27885410759859649565e-02, // 0xBFA0C9A8DF35B713
1.79706750811820387126e-02, // 0x3F9266E7970AF9EC
-1.03142241298341437450e-02, // 0xBF851F9FBA91EC6A
6.10053870246291332635e-03, // 0x3F78FCE0E370E344
-3.68452016781138256760e-03, // 0xBF6E2EFFB3E914D7
2.25964780900612472250e-03, // 0x3F6282D32E15C915
-1.40346469989232843813e-03, // 0xBF56FE8EBF2D1AF1
8.81081882437654011382e-04, // 0x3F4CDF0CEF61A8E9
-5.38595305356740546715e-04, // 0xBF41A6109C73E0EC
3.15632070903625950361e-04, // 0x3F34AF6D6C0EBBF7
-3.12754168375120860518e-04, // 0xBF347F24ECC38C38
3.35529192635519073543e-04, // 0x3F35FD3EE8C2D3F4
}
var _lgamU = [...]float64{
-7.72156649015328655494e-02, // 0xBFB3C467E37DB0C8
6.32827064025093366517e-01, // 0x3FE4401E8B005DFF
1.45492250137234768737e+00, // 0x3FF7475CD119BD6F
9.77717527963372745603e-01, // 0x3FEF497644EA8450
2.28963728064692451092e-01, // 0x3FCD4EAEF6010924
1.33810918536787660377e-02, // 0x3F8B678BBF2BAB09
}
var _lgamV = [...]float64{
1.0,
2.45597793713041134822e+00, // 0x4003A5D7C2BD619C
2.12848976379893395361e+00, // 0x40010725A42B18F5
7.69285150456672783825e-01, // 0x3FE89DFBE45050AF
1.04222645593369134254e-01, // 0x3FBAAE55D6537C88
3.21709242282423911810e-03, // 0x3F6A5ABB57D0CF61
}
var _lgamW = [...]float64{
4.18938533204672725052e-01, // 0x3FDACFE390C97D69
8.33333333333329678849e-02, // 0x3FB555555555553B
-2.77777777728775536470e-03, // 0xBF66C16C16B02E5C
7.93650558643019558500e-04, // 0x3F4A019F98CF38B6
-5.95187557450339963135e-04, // 0xBF4380CB8C0FE741
8.36339918996282139126e-04, // 0x3F4B67BA4CDAD5D1
-1.63092934096575273989e-03, // 0xBF5AB89D0B9E43E4
}
// Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x). // Lgamma returns the natural logarithm and sign (-1 or +1) of Gamma(x).
// //
// Special cases are: // Special cases are:
...@@ -103,68 +178,10 @@ func Lgamma(x float64) (lgamma float64, sign int) { ...@@ -103,68 +178,10 @@ func Lgamma(x float64) (lgamma float64, sign int) {
Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15 Two53 = 1 << 53 // 0x4340000000000000 ~9.0072e+15
Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17 Two58 = 1 << 58 // 0x4390000000000000 ~2.8823e+17
Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22 Tiny = 1.0 / (1 << 70) // 0x3b90000000000000 ~8.47033e-22
A0 = 7.72156649015328655494e-02 // 0x3FB3C467E37DB0C8
A1 = 3.22467033424113591611e-01 // 0x3FD4A34CC4A60FAD
A2 = 6.73523010531292681824e-02 // 0x3FB13E001A5562A7
A3 = 2.05808084325167332806e-02 // 0x3F951322AC92547B
A4 = 7.38555086081402883957e-03 // 0x3F7E404FB68FEFE8
A5 = 2.89051383673415629091e-03 // 0x3F67ADD8CCB7926B
A6 = 1.19270763183362067845e-03 // 0x3F538A94116F3F5D
A7 = 5.10069792153511336608e-04 // 0x3F40B6C689B99C00
A8 = 2.20862790713908385557e-04 // 0x3F2CF2ECED10E54D
A9 = 1.08011567247583939954e-04 // 0x3F1C5088987DFB07
A10 = 2.52144565451257326939e-05 // 0x3EFA7074428CFA52
A11 = 4.48640949618915160150e-05 // 0x3F07858E90A45837
Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F Tc = 1.46163214496836224576e+00 // 0x3FF762D86356BE3F
Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42 Tf = -1.21486290535849611461e-01 // 0xBFBF19B9BCC38A42
// Tt = -(tail of Tf) // Tt = -(tail of Tf)
Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F Tt = -3.63867699703950536541e-18 // 0xBC50C7CAA48A971F
T0 = 4.83836122723810047042e-01 // 0x3FDEF72BC8EE38A2
T1 = -1.47587722994593911752e-01 // 0xBFC2E4278DC6C509
T2 = 6.46249402391333854778e-02 // 0x3FB08B4294D5419B
T3 = -3.27885410759859649565e-02 // 0xBFA0C9A8DF35B713
T4 = 1.79706750811820387126e-02 // 0x3F9266E7970AF9EC
T5 = -1.03142241298341437450e-02 // 0xBF851F9FBA91EC6A
T6 = 6.10053870246291332635e-03 // 0x3F78FCE0E370E344
T7 = -3.68452016781138256760e-03 // 0xBF6E2EFFB3E914D7
T8 = 2.25964780900612472250e-03 // 0x3F6282D32E15C915
T9 = -1.40346469989232843813e-03 // 0xBF56FE8EBF2D1AF1
T10 = 8.81081882437654011382e-04 // 0x3F4CDF0CEF61A8E9
T11 = -5.38595305356740546715e-04 // 0xBF41A6109C73E0EC
T12 = 3.15632070903625950361e-04 // 0x3F34AF6D6C0EBBF7
T13 = -3.12754168375120860518e-04 // 0xBF347F24ECC38C38
T14 = 3.35529192635519073543e-04 // 0x3F35FD3EE8C2D3F4
U0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
U1 = 6.32827064025093366517e-01 // 0x3FE4401E8B005DFF
U2 = 1.45492250137234768737e+00 // 0x3FF7475CD119BD6F
U3 = 9.77717527963372745603e-01 // 0x3FEF497644EA8450
U4 = 2.28963728064692451092e-01 // 0x3FCD4EAEF6010924
U5 = 1.33810918536787660377e-02 // 0x3F8B678BBF2BAB09
V1 = 2.45597793713041134822e+00 // 0x4003A5D7C2BD619C
V2 = 2.12848976379893395361e+00 // 0x40010725A42B18F5
V3 = 7.69285150456672783825e-01 // 0x3FE89DFBE45050AF
V4 = 1.04222645593369134254e-01 // 0x3FBAAE55D6537C88
V5 = 3.21709242282423911810e-03 // 0x3F6A5ABB57D0CF61
S0 = -7.72156649015328655494e-02 // 0xBFB3C467E37DB0C8
S1 = 2.14982415960608852501e-01 // 0x3FCB848B36E20878
S2 = 3.25778796408930981787e-01 // 0x3FD4D98F4F139F59
S3 = 1.46350472652464452805e-01 // 0x3FC2BB9CBEE5F2F7
S4 = 2.66422703033638609560e-02 // 0x3F9B481C7E939961
S5 = 1.84028451407337715652e-03 // 0x3F5E26B67368F239
S6 = 3.19475326584100867617e-05 // 0x3F00BFECDD17E945
R1 = 1.39200533467621045958e+00 // 0x3FF645A762C4AB74
R2 = 7.21935547567138069525e-01 // 0x3FE71A1893D3DCDC
R3 = 1.71933865632803078993e-01 // 0x3FC601EDCCFBDF27
R4 = 1.86459191715652901344e-02 // 0x3F9317EA742ED475
R5 = 7.77942496381893596434e-04 // 0x3F497DDACA41A95B
R6 = 7.32668430744625636189e-06 // 0x3EDEBAF7A5B38140
W0 = 4.18938533204672725052e-01 // 0x3FDACFE390C97D69
W1 = 8.33333333333329678849e-02 // 0x3FB555555555553B
W2 = -2.77777777728775536470e-03 // 0xBF66C16C16B02E5C
W3 = 7.93650558643019558500e-04 // 0x3F4A019F98CF38B6
W4 = -5.95187557450339963135e-04 // 0xBF4380CB8C0FE741
W5 = 8.36339918996282139126e-04 // 0x3F4B67BA4CDAD5D1
W6 = -1.63092934096575273989e-03 // 0xBF5AB89D0B9E43E4
) )
// TODO(rsc): Remove manual inlining of IsNaN, IsInf // TODO(rsc): Remove manual inlining of IsNaN, IsInf
// when compiler does it for us // when compiler does it for us
...@@ -249,28 +266,28 @@ func Lgamma(x float64) (lgamma float64, sign int) { ...@@ -249,28 +266,28 @@ func Lgamma(x float64) (lgamma float64, sign int) {
switch i { switch i {
case 0: case 0:
z := y * y z := y * y
p1 := A0 + z*(A2+z*(A4+z*(A6+z*(A8+z*A10)))) p1 := _lgamA[0] + z*(_lgamA[2]+z*(_lgamA[4]+z*(_lgamA[6]+z*(_lgamA[8]+z*_lgamA[10]))))
p2 := z * (A1 + z*(A3+z*(A5+z*(A7+z*(A9+z*A11))))) p2 := z * (_lgamA[1] + z*(+_lgamA[3]+z*(_lgamA[5]+z*(_lgamA[7]+z*(_lgamA[9]+z*_lgamA[11])))))
p := y*p1 + p2 p := y*p1 + p2
lgamma += (p - 0.5*y) lgamma += (p - 0.5*y)
case 1: case 1:
z := y * y z := y * y
w := z * y w := z * y
p1 := T0 + w*(T3+w*(T6+w*(T9+w*T12))) // parallel comp p1 := _lgamT[0] + w*(_lgamT[3]+w*(_lgamT[6]+w*(_lgamT[9]+w*_lgamT[12]))) // parallel comp
p2 := T1 + w*(T4+w*(T7+w*(T10+w*T13))) p2 := _lgamT[1] + w*(_lgamT[4]+w*(_lgamT[7]+w*(_lgamT[10]+w*_lgamT[13])))
p3 := T2 + w*(T5+w*(T8+w*(T11+w*T14))) p3 := _lgamT[2] + w*(_lgamT[5]+w*(_lgamT[8]+w*(_lgamT[11]+w*_lgamT[14])))
p := z*p1 - (Tt - w*(p2+y*p3)) p := z*p1 - (Tt - w*(p2+y*p3))
lgamma += (Tf + p) lgamma += (Tf + p)
case 2: case 2:
p1 := y * (U0 + y*(U1+y*(U2+y*(U3+y*(U4+y*U5))))) p1 := y * (_lgamU[0] + y*(_lgamU[1]+y*(_lgamU[2]+y*(_lgamU[3]+y*(_lgamU[4]+y*_lgamU[5])))))
p2 := 1 + y*(V1+y*(V2+y*(V3+y*(V4+y*V5)))) p2 := 1 + y*(_lgamV[1]+y*(_lgamV[2]+y*(_lgamV[3]+y*(_lgamV[4]+y*_lgamV[5]))))
lgamma += (-0.5*y + p1/p2) lgamma += (-0.5*y + p1/p2)
} }
case x < 8: // 2 <= x < 8 case x < 8: // 2 <= x < 8
i := int(x) i := int(x)
y := x - float64(i) y := x - float64(i)
p := y * (S0 + y*(S1+y*(S2+y*(S3+y*(S4+y*(S5+y*S6)))))) p := y * (_lgamS[0] + y*(_lgamS[1]+y*(_lgamS[2]+y*(_lgamS[3]+y*(_lgamS[4]+y*(_lgamS[5]+y*_lgamS[6]))))))
q := 1 + y*(R1+y*(R2+y*(R3+y*(R4+y*(R5+y*R6))))) q := 1 + y*(_lgamR[1]+y*(_lgamR[2]+y*(_lgamR[3]+y*(_lgamR[4]+y*(_lgamR[5]+y*_lgamR[6])))))
lgamma = 0.5*y + p/q lgamma = 0.5*y + p/q
z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s) z := 1.0 // Lgamma(1+s) = Log(s) + Lgamma(s)
switch i { switch i {
...@@ -294,7 +311,7 @@ func Lgamma(x float64) (lgamma float64, sign int) { ...@@ -294,7 +311,7 @@ func Lgamma(x float64) (lgamma float64, sign int) {
t := Log(x) t := Log(x)
z := 1 / x z := 1 / x
y := z * z y := z * z
w := W0 + z*(W1+y*(W2+y*(W3+y*(W4+y*(W5+y*W6))))) w := _lgamW[0] + z*(_lgamW[1]+y*(_lgamW[2]+y*(_lgamW[3]+y*(_lgamW[4]+y*(_lgamW[5]+y*_lgamW[6])))))
lgamma = (x-0.5)*(t-1) + w lgamma = (x-0.5)*(t-1) + w
default: // 2**58 <= x <= Inf default: // 2**58 <= x <= Inf
lgamma = x * (Log(x) - 1) lgamma = x * (Log(x) - 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment