Commit b740cb63 by Ian Lance Taylor

libgo: update to weekly.2011-10-25

Changes were mainly straightforward to merge.

From-SVN: r181824
parent cebc182b
6d7136d74b65 941b8015061a
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.
...@@ -236,11 +236,19 @@ toolexeclibgoencoding_DATA = \ ...@@ -236,11 +236,19 @@ toolexeclibgoencoding_DATA = \
encoding/hex.gox \ encoding/hex.gox \
encoding/pem.gox encoding/pem.gox
if LIBGO_IS_LINUX
# exp_inotify_gox = exp/inotify.gox
exp_inotify_gox =
else
exp_inotify_gox =
endif
toolexeclibgoexpdir = $(toolexeclibgodir)/exp toolexeclibgoexpdir = $(toolexeclibgodir)/exp
toolexeclibgoexp_DATA = \ toolexeclibgoexp_DATA = \
exp/ebnf.gox \ exp/ebnf.gox \
exp/gui.gox \ exp/gui.gox \
$(exp_inotify_gox) \
exp/norm.gox \ exp/norm.gox \
exp/spdy.gox \ exp/spdy.gox \
exp/sql.gox \ exp/sql.gox \
...@@ -332,15 +340,7 @@ toolexeclibgoold_DATA = \ ...@@ -332,15 +340,7 @@ toolexeclibgoold_DATA = \
toolexeclibgoosdir = $(toolexeclibgodir)/os toolexeclibgoosdir = $(toolexeclibgodir)/os
if LIBGO_IS_LINUX
# os_inotify_gox = os/inotify.gox
os_inotify_gox =
else
os_inotify_gox =
endif
toolexeclibgoos_DATA = \ toolexeclibgoos_DATA = \
$(os_inotify_gox) \
os/user.gox \ os/user.gox \
os/signal.gox os/signal.gox
...@@ -1212,6 +1212,8 @@ go_exp_ebnf_files = \ ...@@ -1212,6 +1212,8 @@ go_exp_ebnf_files = \
go/exp/ebnf/parser.go go/exp/ebnf/parser.go
go_exp_gui_files = \ go_exp_gui_files = \
go/exp/gui/gui.go go/exp/gui/gui.go
go_exp_inotify_files = \
go/exp/inotify/inotify_linux.go
go_exp_norm_files = \ go_exp_norm_files = \
go/exp/norm/composition.go \ go/exp/norm/composition.go \
go/exp/norm/forminfo.go \ go/exp/norm/forminfo.go \
...@@ -1229,11 +1231,13 @@ go_exp_sql_files = \ ...@@ -1229,11 +1231,13 @@ go_exp_sql_files = \
go/exp/sql/sql.go go/exp/sql/sql.go
go_exp_ssh_files = \ go_exp_ssh_files = \
go/exp/ssh/channel.go \ go/exp/ssh/channel.go \
go/exp/ssh/client.go \
go/exp/ssh/common.go \ go/exp/ssh/common.go \
go/exp/ssh/doc.go \ go/exp/ssh/doc.go \
go/exp/ssh/messages.go \ go/exp/ssh/messages.go \
go/exp/ssh/server.go \ go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \ go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \
go/exp/ssh/transport.go go/exp/ssh/transport.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/shell.go \ go/exp/terminal/shell.go \
...@@ -1387,9 +1391,6 @@ go_old_template_files = \ ...@@ -1387,9 +1391,6 @@ go_old_template_files = \
go/old/template/format.go \ go/old/template/format.go \
go/old/template/parse.go go/old/template/parse.go
go_os_inotify_files = \
go/os/inotify/inotify_linux.go
go_os_user_files = \ go_os_user_files = \
go/os/user/user.go \ go/os/user/user.go \
go/os/user/lookup_unix.go go/os/user/lookup_unix.go
...@@ -2723,6 +2724,13 @@ exp/gui/x11/check: $(CHECK_DEPS) ...@@ -2723,6 +2724,13 @@ exp/gui/x11/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: exp/gui/x11/check .PHONY: exp/gui/x11/check
exp/inotify.lo: $(go_exp_inotify_files) fmt.gox os.gox strings.gox syscall.gox
$(BUILDPACKAGE)
exp/inotify/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/inotify
@$(CHECK)
.PHONY: exp/inotify/check
exp/sql/driver.lo: $(go_exp_sql_driver_files) fmt.gox os.gox reflect.gox \ exp/sql/driver.lo: $(go_exp_sql_driver_files) fmt.gox os.gox reflect.gox \
strconv.gox strconv.gox
$(BUILDPACKAGE) $(BUILDPACKAGE)
...@@ -2998,13 +3006,6 @@ old/template/check: $(CHECK_DEPS) ...@@ -2998,13 +3006,6 @@ old/template/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: old/template/check .PHONY: old/template/check
os/inotify.lo: $(go_os_inotify_files) fmt.gox os.gox strings.gox syscall.gox
$(BUILDPACKAGE)
os/inotify/check: $(CHECK_DEPS)
@$(MKDIR_P) os/inotify
@$(CHECK)
.PHONY: os/inotify/check
os/user.lo: $(go_os_user_files) fmt.gox os.gox runtime.gox strconv.gox \ os/user.lo: $(go_os_user_files) fmt.gox os.gox runtime.gox strconv.gox \
strings.gox syscall.gox strings.gox syscall.gox
$(BUILDPACKAGE) $(BUILDPACKAGE)
...@@ -3331,6 +3332,8 @@ exp/ebnf.gox: exp/ebnf.lo ...@@ -3331,6 +3332,8 @@ exp/ebnf.gox: exp/ebnf.lo
$(BUILDGOX) $(BUILDGOX)
exp/gui.gox: exp/gui.lo exp/gui.gox: exp/gui.lo
$(BUILDGOX) $(BUILDGOX)
exp/inotify.gox: exp/inotify.lo
$(BUILDGOX)
exp/norm.gox: exp/norm.lo exp/norm.gox: exp/norm.lo
$(BUILDGOX) $(BUILDGOX)
exp/spdy.gox: exp/spdy.lo exp/spdy.gox: exp/spdy.lo
...@@ -3424,8 +3427,6 @@ old/regexp.gox: old/regexp.lo ...@@ -3424,8 +3427,6 @@ old/regexp.gox: old/regexp.lo
old/template.gox: old/template.lo old/template.gox: old/template.lo
$(BUILDGOX) $(BUILDGOX)
os/inotify.gox: os/inotify.lo
$(BUILDGOX)
os/user.gox: os/user.lo os/user.gox: os/user.lo
$(BUILDGOX) $(BUILDGOX)
os/signal.gox: os/signal.lo os/signal.gox: os/signal.lo
...@@ -3459,10 +3460,10 @@ testing/script.gox: testing/script.lo ...@@ -3459,10 +3460,10 @@ testing/script.gox: testing/script.lo
$(BUILDGOX) $(BUILDGOX)
if LIBGO_IS_LINUX if LIBGO_IS_LINUX
# os_inotify_check = os/inotify/check # exp_inotify_check = exp/inotify/check
os_inotify_check = exp_inotify_check =
else else
os_inotify_check = exp_inotify_check =
endif endif
TEST_PACKAGES = \ TEST_PACKAGES = \
...@@ -3563,6 +3564,7 @@ TEST_PACKAGES = \ ...@@ -3563,6 +3564,7 @@ TEST_PACKAGES = \
encoding/hex/check \ encoding/hex/check \
encoding/pem/check \ encoding/pem/check \
exp/ebnf/check \ exp/ebnf/check \
$(exp_inotify_check) \
exp/norm/check \ exp/norm/check \
exp/spdy/check \ exp/spdy/check \
exp/sql/check \ exp/sql/check \
...@@ -3594,7 +3596,6 @@ TEST_PACKAGES = \ ...@@ -3594,7 +3596,6 @@ TEST_PACKAGES = \
old/netchan/check \ old/netchan/check \
old/regexp/check \ old/regexp/check \
old/template/check \ old/template/check \
$(os_inotify_check) \
os/user/check \ os/user/check \
os/signal/check \ os/signal/check \
path/filepath/check \ path/filepath/check \
......
...@@ -699,10 +699,15 @@ toolexeclibgoencoding_DATA = \ ...@@ -699,10 +699,15 @@ toolexeclibgoencoding_DATA = \
encoding/hex.gox \ encoding/hex.gox \
encoding/pem.gox encoding/pem.gox
@LIBGO_IS_LINUX_FALSE@exp_inotify_gox =
# exp_inotify_gox = exp/inotify.gox
@LIBGO_IS_LINUX_TRUE@exp_inotify_gox =
toolexeclibgoexpdir = $(toolexeclibgodir)/exp toolexeclibgoexpdir = $(toolexeclibgodir)/exp
toolexeclibgoexp_DATA = \ toolexeclibgoexp_DATA = \
exp/ebnf.gox \ exp/ebnf.gox \
exp/gui.gox \ exp/gui.gox \
$(exp_inotify_gox) \
exp/norm.gox \ exp/norm.gox \
exp/spdy.gox \ exp/spdy.gox \
exp/sql.gox \ exp/sql.gox \
...@@ -781,12 +786,7 @@ toolexeclibgoold_DATA = \ ...@@ -781,12 +786,7 @@ toolexeclibgoold_DATA = \
old/template.gox old/template.gox
toolexeclibgoosdir = $(toolexeclibgodir)/os toolexeclibgoosdir = $(toolexeclibgodir)/os
@LIBGO_IS_LINUX_FALSE@os_inotify_gox =
# os_inotify_gox = os/inotify.gox
@LIBGO_IS_LINUX_TRUE@os_inotify_gox =
toolexeclibgoos_DATA = \ toolexeclibgoos_DATA = \
$(os_inotify_gox) \
os/user.gox \ os/user.gox \
os/signal.gox os/signal.gox
...@@ -1579,6 +1579,9 @@ go_exp_ebnf_files = \ ...@@ -1579,6 +1579,9 @@ go_exp_ebnf_files = \
go_exp_gui_files = \ go_exp_gui_files = \
go/exp/gui/gui.go go/exp/gui/gui.go
go_exp_inotify_files = \
go/exp/inotify/inotify_linux.go
go_exp_norm_files = \ go_exp_norm_files = \
go/exp/norm/composition.go \ go/exp/norm/composition.go \
go/exp/norm/forminfo.go \ go/exp/norm/forminfo.go \
...@@ -1599,11 +1602,13 @@ go_exp_sql_files = \ ...@@ -1599,11 +1602,13 @@ go_exp_sql_files = \
go_exp_ssh_files = \ go_exp_ssh_files = \
go/exp/ssh/channel.go \ go/exp/ssh/channel.go \
go/exp/ssh/client.go \
go/exp/ssh/common.go \ go/exp/ssh/common.go \
go/exp/ssh/doc.go \ go/exp/ssh/doc.go \
go/exp/ssh/messages.go \ go/exp/ssh/messages.go \
go/exp/ssh/server.go \ go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \ go/exp/ssh/server_shell.go \
go/exp/ssh/session.go \
go/exp/ssh/transport.go go/exp/ssh/transport.go
go_exp_terminal_files = \ go_exp_terminal_files = \
...@@ -1773,9 +1778,6 @@ go_old_template_files = \ ...@@ -1773,9 +1778,6 @@ go_old_template_files = \
go/old/template/format.go \ go/old/template/format.go \
go/old/template/parse.go go/old/template/parse.go
go_os_inotify_files = \
go/os/inotify/inotify_linux.go
go_os_user_files = \ go_os_user_files = \
go/os/user/user.go \ go/os/user/user.go \
go/os/user/lookup_unix.go go/os/user/lookup_unix.go
...@@ -2171,10 +2173,10 @@ BUILDGOX = \ ...@@ -2171,10 +2173,10 @@ BUILDGOX = \
f=`echo $< | sed -e 's/.lo$$/.o/'`; \ f=`echo $< | sed -e 's/.lo$$/.o/'`; \
$(OBJCOPY) -j .go_export $$f $@.tmp && mv -f $@.tmp $@ $(OBJCOPY) -j .go_export $$f $@.tmp && mv -f $@.tmp $@
@LIBGO_IS_LINUX_FALSE@os_inotify_check = @LIBGO_IS_LINUX_FALSE@exp_inotify_check =
# os_inotify_check = os/inotify/check # exp_inotify_check = exp/inotify/check
@LIBGO_IS_LINUX_TRUE@os_inotify_check = @LIBGO_IS_LINUX_TRUE@exp_inotify_check =
TEST_PACKAGES = \ TEST_PACKAGES = \
asn1/check \ asn1/check \
big/check \ big/check \
...@@ -2273,6 +2275,7 @@ TEST_PACKAGES = \ ...@@ -2273,6 +2275,7 @@ TEST_PACKAGES = \
encoding/hex/check \ encoding/hex/check \
encoding/pem/check \ encoding/pem/check \
exp/ebnf/check \ exp/ebnf/check \
$(exp_inotify_check) \
exp/norm/check \ exp/norm/check \
exp/spdy/check \ exp/spdy/check \
exp/sql/check \ exp/sql/check \
...@@ -2304,7 +2307,6 @@ TEST_PACKAGES = \ ...@@ -2304,7 +2307,6 @@ TEST_PACKAGES = \
old/netchan/check \ old/netchan/check \
old/regexp/check \ old/regexp/check \
old/template/check \ old/template/check \
$(os_inotify_check) \
os/user/check \ os/user/check \
os/signal/check \ os/signal/check \
path/filepath/check \ path/filepath/check \
...@@ -5326,6 +5328,13 @@ exp/gui/x11/check: $(CHECK_DEPS) ...@@ -5326,6 +5328,13 @@ exp/gui/x11/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: exp/gui/x11/check .PHONY: exp/gui/x11/check
exp/inotify.lo: $(go_exp_inotify_files) fmt.gox os.gox strings.gox syscall.gox
$(BUILDPACKAGE)
exp/inotify/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/inotify
@$(CHECK)
.PHONY: exp/inotify/check
exp/sql/driver.lo: $(go_exp_sql_driver_files) fmt.gox os.gox reflect.gox \ exp/sql/driver.lo: $(go_exp_sql_driver_files) fmt.gox os.gox reflect.gox \
strconv.gox strconv.gox
$(BUILDPACKAGE) $(BUILDPACKAGE)
...@@ -5601,13 +5610,6 @@ old/template/check: $(CHECK_DEPS) ...@@ -5601,13 +5610,6 @@ old/template/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: old/template/check .PHONY: old/template/check
os/inotify.lo: $(go_os_inotify_files) fmt.gox os.gox strings.gox syscall.gox
$(BUILDPACKAGE)
os/inotify/check: $(CHECK_DEPS)
@$(MKDIR_P) os/inotify
@$(CHECK)
.PHONY: os/inotify/check
os/user.lo: $(go_os_user_files) fmt.gox os.gox runtime.gox strconv.gox \ os/user.lo: $(go_os_user_files) fmt.gox os.gox runtime.gox strconv.gox \
strings.gox syscall.gox strings.gox syscall.gox
$(BUILDPACKAGE) $(BUILDPACKAGE)
...@@ -5929,6 +5931,8 @@ exp/ebnf.gox: exp/ebnf.lo ...@@ -5929,6 +5931,8 @@ exp/ebnf.gox: exp/ebnf.lo
$(BUILDGOX) $(BUILDGOX)
exp/gui.gox: exp/gui.lo exp/gui.gox: exp/gui.lo
$(BUILDGOX) $(BUILDGOX)
exp/inotify.gox: exp/inotify.lo
$(BUILDGOX)
exp/norm.gox: exp/norm.lo exp/norm.gox: exp/norm.lo
$(BUILDGOX) $(BUILDGOX)
exp/spdy.gox: exp/spdy.lo exp/spdy.gox: exp/spdy.lo
...@@ -6022,8 +6026,6 @@ old/regexp.gox: old/regexp.lo ...@@ -6022,8 +6026,6 @@ old/regexp.gox: old/regexp.lo
old/template.gox: old/template.lo old/template.gox: old/template.lo
$(BUILDGOX) $(BUILDGOX)
os/inotify.gox: os/inotify.lo
$(BUILDGOX)
os/user.gox: os/user.lo os/user.gox: os/user.lo
$(BUILDGOX) $(BUILDGOX)
os/signal.gox: os/signal.lo os/signal.gox: os/signal.lo
......
...@@ -58,22 +58,24 @@ func NewInt(x int64) *Int { ...@@ -58,22 +58,24 @@ func NewInt(x int64) *Int {
// Set sets z to x and returns z. // Set sets z to x and returns z.
func (z *Int) Set(x *Int) *Int { func (z *Int) Set(x *Int) *Int {
if z != x {
z.abs = z.abs.set(x.abs) z.abs = z.abs.set(x.abs)
z.neg = x.neg z.neg = x.neg
}
return z return z
} }
// Abs sets z to |x| (the absolute value of x) and returns z. // Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Int) Abs(x *Int) *Int { func (z *Int) Abs(x *Int) *Int {
z.abs = z.abs.set(x.abs) z.Set(x)
z.neg = false z.neg = false
return z return z
} }
// Neg sets z to -x and returns z. // Neg sets z to -x and returns z.
func (z *Int) Neg(x *Int) *Int { func (z *Int) Neg(x *Int) *Int {
z.abs = z.abs.set(x.abs) z.Set(x)
z.neg = len(z.abs) > 0 && !x.neg // 0 has no sign z.neg = len(z.abs) > 0 && !z.neg // 0 has no sign
return z return z
} }
...@@ -174,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int { ...@@ -174,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int {
// If y == 0, a division-by-zero run-time panic occurs. // If y == 0, a division-by-zero run-time panic occurs.
// Rem implements truncated modulus (like Go); see QuoRem for more details. // Rem implements truncated modulus (like Go); see QuoRem for more details.
func (z *Int) Rem(x, y *Int) *Int { func (z *Int) Rem(x, y *Int) *Int {
_, z.abs = nat(nil).div(z.abs, x.abs, y.abs) _, z.abs = nat{}.div(z.abs, x.abs, y.abs)
z.neg = len(z.abs) > 0 && x.neg // 0 has no sign z.neg = len(z.abs) > 0 && x.neg // 0 has no sign
return z return z
} }
...@@ -422,8 +424,8 @@ func (x *Int) Format(s fmt.State, ch int) { ...@@ -422,8 +424,8 @@ func (x *Int) Format(s fmt.State, ch int) {
// scan sets z to the integer value corresponding to the longest possible prefix // scan sets z to the integer value corresponding to the longest possible prefix
// read from r representing a signed integer number in a given conversion base. // read from r representing a signed integer number in a given conversion base.
// It returns z, the actual conversion base used, and an error, if any. In the // It returns z, the actual conversion base used, and an error, if any. In the
// error case, the value of z is undefined. The syntax follows the syntax of // error case, the value of z is undefined but the returned value is nil. The
// integer literals in Go. // syntax follows the syntax of integer literals in Go.
// //
// The base argument must be 0 or a value from 2 through MaxBase. If the base // The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of // is 0, the string prefix determines the actual conversion base. A prefix of
...@@ -434,7 +436,7 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) { ...@@ -434,7 +436,7 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) {
// determine sign // determine sign
ch, _, err := r.ReadRune() ch, _, err := r.ReadRune()
if err != nil { if err != nil {
return z, 0, err return nil, 0, err
} }
neg := false neg := false
switch ch { switch ch {
...@@ -448,7 +450,7 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) { ...@@ -448,7 +450,7 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) {
// determine mantissa // determine mantissa
z.abs, base, err = z.abs.scan(r, base) z.abs, base, err = z.abs.scan(r, base)
if err != nil { if err != nil {
return z, base, err return nil, base, err
} }
z.neg = len(z.abs) > 0 && neg // 0 has no sign z.neg = len(z.abs) > 0 && neg // 0 has no sign
...@@ -497,7 +499,7 @@ func (x *Int) Int64() int64 { ...@@ -497,7 +499,7 @@ func (x *Int) Int64() int64 {
// SetString sets z to the value of s, interpreted in the given base, // SetString sets z to the value of s, interpreted in the given base,
// and returns z and a boolean indicating success. If SetString fails, // and returns z and a boolean indicating success. If SetString fails,
// the value of z is undefined. // the value of z is undefined but the returned value is nil.
// //
// The base argument must be 0 or a value from 2 through MaxBase. If the base // The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of // is 0, the string prefix determines the actual conversion base. A prefix of
...@@ -508,10 +510,13 @@ func (z *Int) SetString(s string, base int) (*Int, bool) { ...@@ -508,10 +510,13 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
r := strings.NewReader(s) r := strings.NewReader(s)
_, _, err := z.scan(r, base) _, _, err := z.scan(r, base)
if err != nil { if err != nil {
return z, false return nil, false
} }
_, _, err = r.ReadRune() _, _, err = r.ReadRune()
return z, err == os.EOF // err == os.EOF => scan consumed all of s if err != os.EOF {
return nil, false
}
return z, true // err == os.EOF => scan consumed all of s
} }
// SetBytes interprets buf as the bytes of a big-endian unsigned // SetBytes interprets buf as the bytes of a big-endian unsigned
......
...@@ -311,7 +311,16 @@ func TestSetString(t *testing.T) { ...@@ -311,7 +311,16 @@ func TestSetString(t *testing.T) {
t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok) t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok)
continue continue
} }
if !ok1 || !ok2 { if !ok1 {
if n1 != nil {
t.Errorf("#%d (input '%s') n1 != nil", i, test.in)
}
continue
}
if !ok2 {
if n2 != nil {
t.Errorf("#%d (input '%s') n2 != nil", i, test.in)
}
continue continue
} }
......
...@@ -35,7 +35,7 @@ import ( ...@@ -35,7 +35,7 @@ import (
// During arithmetic operations, denormalized values may occur but are // During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized // always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0). // representation of 0 is the empty or nil slice (length = 0).
//
type nat []Word type nat []Word
var ( var (
...@@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat { ...@@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat {
case a == b: case a == b:
return z.setUint64(a) return z.setUint64(a)
case a+1 == b: case a+1 == b:
return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) return z.mul(nat{}.setUint64(a), nat{}.setUint64(b))
} }
m := (a + b) / 2 m := (a + b) / 2
return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) return z.mul(nat{}.mulRange(a, m), nat{}.mulRange(m+1, b))
} }
// q = (x-r)/y, with 0 <= r < y // q = (x-r)/y, with 0 <= r < y
...@@ -589,7 +589,6 @@ func (x nat) bitLen() int { ...@@ -589,7 +589,6 @@ func (x nat) bitLen() int {
// MaxBase is the largest number base accepted for string conversions. // MaxBase is the largest number base accepted for string conversions.
const MaxBase = 'z' - 'a' + 10 + 1 // = hexValue('z') + 1 const MaxBase = 'z' - 'a' + 10 + 1 // = hexValue('z') + 1
func hexValue(ch int) Word { func hexValue(ch int) Word {
d := MaxBase + 1 // illegal base d := MaxBase + 1 // illegal base
switch { switch {
...@@ -786,7 +785,7 @@ func (x nat) string(charset string) string { ...@@ -786,7 +785,7 @@ func (x nat) string(charset string) string {
} }
// preserve x, create local copy for use in repeated divisions // preserve x, create local copy for use in repeated divisions
q := nat(nil).set(x) q := nat{}.set(x)
var r Word var r Word
// convert // convert
...@@ -1192,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool { ...@@ -1192,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool {
return false return false
} }
nm1 := nat(nil).sub(n, natOne) nm1 := nat{}.sub(n, natOne)
// 1<<k * q = nm1; // 1<<k * q = nm1;
q, k := nm1.powersOfTwoDecompose() q, k := nm1.powersOfTwoDecompose()
nm3 := nat(nil).sub(nm1, natTwo) nm3 := nat{}.sub(nm1, natTwo)
rand := rand.New(rand.NewSource(int64(n[0]))) rand := rand.New(rand.NewSource(int64(n[0])))
var x, y, quotient nat var x, y, quotient nat
......
...@@ -67,7 +67,7 @@ var prodNN = []argNN{ ...@@ -67,7 +67,7 @@ var prodNN = []argNN{
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
for _, a := range sumNN { for _, a := range sumNN {
z := nat(nil).set(a.z) z := nat{}.set(a.z)
if z.cmp(a.z) != 0 { if z.cmp(a.z) != 0 {
t.Errorf("got z = %v; want %v", z, a.z) t.Errorf("got z = %v; want %v", z, a.z)
} }
...@@ -129,7 +129,7 @@ var mulRangesN = []struct { ...@@ -129,7 +129,7 @@ var mulRangesN = []struct {
func TestMulRangeN(t *testing.T) { func TestMulRangeN(t *testing.T) {
for i, r := range mulRangesN { for i, r := range mulRangesN {
prod := nat(nil).mulRange(r.a, r.b).decimalString() prod := nat{}.mulRange(r.a, r.b).decimalString()
if prod != r.prod { if prod != r.prod {
t.Errorf("#%d: got %s; want %s", i, prod, r.prod) t.Errorf("#%d: got %s; want %s", i, prod, r.prod)
} }
...@@ -175,7 +175,7 @@ func toString(x nat, charset string) string { ...@@ -175,7 +175,7 @@ func toString(x nat, charset string) string {
s := make([]byte, i) s := make([]byte, i)
// don't destroy x // don't destroy x
q := nat(nil).set(x) q := nat{}.set(x)
// convert // convert
for len(q) > 0 { for len(q) > 0 {
...@@ -212,7 +212,7 @@ func TestString(t *testing.T) { ...@@ -212,7 +212,7 @@ func TestString(t *testing.T) {
t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s) t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s)
} }
x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c)) x, b, err := nat{}.scan(strings.NewReader(a.s), len(a.c))
if x.cmp(a.x) != 0 { if x.cmp(a.x) != 0 {
t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x)
} }
...@@ -271,7 +271,7 @@ var natScanTests = []struct { ...@@ -271,7 +271,7 @@ var natScanTests = []struct {
func TestScanBase(t *testing.T) { func TestScanBase(t *testing.T) {
for _, a := range natScanTests { for _, a := range natScanTests {
r := strings.NewReader(a.s) r := strings.NewReader(a.s)
x, b, err := nat(nil).scan(r, a.base) x, b, err := nat{}.scan(r, a.base)
if err == nil && !a.ok { if err == nil && !a.ok {
t.Errorf("scan%+v\n\texpected error", a) t.Errorf("scan%+v\n\texpected error", a)
} }
...@@ -651,17 +651,17 @@ var expNNTests = []struct { ...@@ -651,17 +651,17 @@ var expNNTests = []struct {
func TestExpNN(t *testing.T) { func TestExpNN(t *testing.T) {
for i, test := range expNNTests { for i, test := range expNNTests {
x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0) x, _, _ := nat{}.scan(strings.NewReader(test.x), 0)
y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0) y, _, _ := nat{}.scan(strings.NewReader(test.y), 0)
out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0) out, _, _ := nat{}.scan(strings.NewReader(test.out), 0)
var m nat var m nat
if len(test.m) > 0 { if len(test.m) > 0 {
m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0) m, _, _ = nat{}.scan(strings.NewReader(test.m), 0)
} }
z := nat(nil).expNN(x, y, m) z := nat{}.expNN(x, y, m)
if z.cmp(out) != 0 { if z.cmp(out) != 0 {
t.Errorf("#%d got %v want %v", i, z, out) t.Errorf("#%d got %v want %v", i, z, out)
} }
......
...@@ -13,11 +13,11 @@ import ( ...@@ -13,11 +13,11 @@ import (
"strings" "strings"
) )
// A Rat represents a quotient a/b of arbitrary precision. The zero value for // A Rat represents a quotient a/b of arbitrary precision.
// a Rat, 0/0, is not a legal Rat. // The zero value for a Rat represents the value 0.
type Rat struct { type Rat struct {
a Int a Int
b nat b nat // len(b) == 0 acts like b == 1
} }
// NewRat creates a new Rat with numerator a and denominator b. // NewRat creates a new Rat with numerator a and denominator b.
...@@ -29,8 +29,11 @@ func NewRat(a, b int64) *Rat { ...@@ -29,8 +29,11 @@ func NewRat(a, b int64) *Rat {
func (z *Rat) SetFrac(a, b *Int) *Rat { func (z *Rat) SetFrac(a, b *Int) *Rat {
z.a.neg = a.neg != b.neg z.a.neg = a.neg != b.neg
babs := b.abs babs := b.abs
if len(babs) == 0 {
panic("division by zero")
}
if &z.a == b || alias(z.a.abs, babs) { if &z.a == b || alias(z.a.abs, babs) {
babs = nat(nil).set(babs) // make a copy babs = nat{}.set(babs) // make a copy
} }
z.a.abs = z.a.abs.set(a.abs) z.a.abs = z.a.abs.set(a.abs)
z.b = z.b.set(babs) z.b = z.b.set(babs)
...@@ -40,6 +43,9 @@ func (z *Rat) SetFrac(a, b *Int) *Rat { ...@@ -40,6 +43,9 @@ func (z *Rat) SetFrac(a, b *Int) *Rat {
// SetFrac64 sets z to a/b and returns z. // SetFrac64 sets z to a/b and returns z.
func (z *Rat) SetFrac64(a, b int64) *Rat { func (z *Rat) SetFrac64(a, b int64) *Rat {
z.a.SetInt64(a) z.a.SetInt64(a)
if b == 0 {
panic("division by zero")
}
if b < 0 { if b < 0 {
b = -b b = -b
z.a.neg = !z.a.neg z.a.neg = !z.a.neg
...@@ -51,14 +57,55 @@ func (z *Rat) SetFrac64(a, b int64) *Rat { ...@@ -51,14 +57,55 @@ func (z *Rat) SetFrac64(a, b int64) *Rat {
// SetInt sets z to x (by making a copy of x) and returns z. // SetInt sets z to x (by making a copy of x) and returns z.
func (z *Rat) SetInt(x *Int) *Rat { func (z *Rat) SetInt(x *Int) *Rat {
z.a.Set(x) z.a.Set(x)
z.b = z.b.setWord(1) z.b = z.b.make(0)
return z return z
} }
// SetInt64 sets z to x and returns z. // SetInt64 sets z to x and returns z.
func (z *Rat) SetInt64(x int64) *Rat { func (z *Rat) SetInt64(x int64) *Rat {
z.a.SetInt64(x) z.a.SetInt64(x)
z.b = z.b.setWord(1) z.b = z.b.make(0)
return z
}
// Set sets z to x (by making a copy of x) and returns z.
func (z *Rat) Set(x *Rat) *Rat {
if z != x {
z.a.Set(&x.a)
z.b = z.b.set(x.b)
}
return z
}
// Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Rat) Abs(x *Rat) *Rat {
z.Set(x)
z.a.neg = false
return z
}
// Neg sets z to -x and returns z.
func (z *Rat) Neg(x *Rat) *Rat {
z.Set(x)
z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign
return z
}
// Inv sets z to 1/x and returns z.
func (z *Rat) Inv(x *Rat) *Rat {
if len(x.a.abs) == 0 {
panic("division by zero")
}
z.Set(x)
a := z.b
if len(a) == 0 {
a = a.setWord(1) // materialize numerator
}
b := z.a.abs
if b.cmp(natOne) == 0 {
b = b.make(0) // normalize denominator
}
z.a.abs, z.b = a, b // sign doesn't change
return z return z
} }
...@@ -74,21 +121,24 @@ func (x *Rat) Sign() int { ...@@ -74,21 +121,24 @@ func (x *Rat) Sign() int {
// IsInt returns true if the denominator of x is 1. // IsInt returns true if the denominator of x is 1.
func (x *Rat) IsInt() bool { func (x *Rat) IsInt() bool {
return len(x.b) == 1 && x.b[0] == 1 return len(x.b) == 0 || x.b.cmp(natOne) == 0
} }
// Num returns the numerator of z; it may be <= 0. // Num returns the numerator of x; it may be <= 0.
// The result is a reference to z's numerator; it // The result is a reference to x's numerator; it
// may change if a new value is assigned to z. // may change if a new value is assigned to x.
func (z *Rat) Num() *Int { func (x *Rat) Num() *Int {
return &z.a return &x.a
} }
// Denom returns the denominator of z; it is always > 0. // Denom returns the denominator of x; it is always > 0.
// The result is a reference to z's denominator; it // The result is a reference to x's denominator; it
// may change if a new value is assigned to z. // may change if a new value is assigned to x.
func (z *Rat) Denom() *Int { func (x *Rat) Denom() *Int {
return &Int{false, z.b} if len(x.b) == 0 {
return &Int{abs: nat{1}}
}
return &Int{abs: x.b}
} }
func gcd(x, y nat) nat { func gcd(x, y nat) nat {
...@@ -106,24 +156,47 @@ func gcd(x, y nat) nat { ...@@ -106,24 +156,47 @@ func gcd(x, y nat) nat {
} }
func (z *Rat) norm() *Rat { func (z *Rat) norm() *Rat {
f := gcd(z.a.abs, z.b) switch {
if len(z.a.abs) == 0 { case len(z.a.abs) == 0:
// z == 0 // z == 0 - normalize sign and denominator
z.a.neg = false // normalize sign z.a.neg = false
z.b = z.b.setWord(1) z.b = z.b.make(0)
return z case len(z.b) == 0:
} // z is normalized int - nothing to do
if f.cmp(natOne) != 0 { case z.b.cmp(natOne) == 0:
// z is int - normalize denominator
z.b = z.b.make(0)
default:
if f := gcd(z.a.abs, z.b); f.cmp(natOne) != 0 {
z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f) z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f)
z.b, _ = z.b.div(nil, z.b, f) z.b, _ = z.b.div(nil, z.b, f)
} }
}
return z return z
} }
func mulNat(x *Int, y nat) *Int { // mulDenom sets z to the denominator product x*y (by taking into
// account that 0 values for x or y must be interpreted as 1) and
// returns z.
func mulDenom(z, x, y nat) nat {
switch {
case len(x) == 0:
return z.set(y)
case len(y) == 0:
return z.set(x)
}
return z.mul(x, y)
}
// scaleDenom computes x*f.
// If f == 0 (zero value of denominator), the result is (a copy of) x.
func scaleDenom(x *Int, f nat) *Int {
var z Int var z Int
z.abs = z.abs.mul(x.abs, y) if len(f) == 0 {
z.neg = len(z.abs) > 0 && x.neg return z.Set(x)
}
z.abs = z.abs.mul(x.abs, f)
z.neg = x.neg
return &z return &z
} }
...@@ -133,39 +206,32 @@ func mulNat(x *Int, y nat) *Int { ...@@ -133,39 +206,32 @@ func mulNat(x *Int, y nat) *Int {
// 0 if x == y // 0 if x == y
// +1 if x > y // +1 if x > y
// //
func (x *Rat) Cmp(y *Rat) (r int) { func (x *Rat) Cmp(y *Rat) int {
return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b)) return scaleDenom(&x.a, y.b).Cmp(scaleDenom(&y.a, x.b))
}
// Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Rat) Abs(x *Rat) *Rat {
z.a.Abs(&x.a)
z.b = z.b.set(x.b)
return z
} }
// Add sets z to the sum x+y and returns z. // Add sets z to the sum x+y and returns z.
func (z *Rat) Add(x, y *Rat) *Rat { func (z *Rat) Add(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b) a1 := scaleDenom(&x.a, y.b)
a2 := mulNat(&y.a, x.b) a2 := scaleDenom(&y.a, x.b)
z.a.Add(a1, a2) z.a.Add(a1, a2)
z.b = z.b.mul(x.b, y.b) z.b = mulDenom(z.b, x.b, y.b)
return z.norm() return z.norm()
} }
// Sub sets z to the difference x-y and returns z. // Sub sets z to the difference x-y and returns z.
func (z *Rat) Sub(x, y *Rat) *Rat { func (z *Rat) Sub(x, y *Rat) *Rat {
a1 := mulNat(&x.a, y.b) a1 := scaleDenom(&x.a, y.b)
a2 := mulNat(&y.a, x.b) a2 := scaleDenom(&y.a, x.b)
z.a.Sub(a1, a2) z.a.Sub(a1, a2)
z.b = z.b.mul(x.b, y.b) z.b = mulDenom(z.b, x.b, y.b)
return z.norm() return z.norm()
} }
// Mul sets z to the product x*y and returns z. // Mul sets z to the product x*y and returns z.
func (z *Rat) Mul(x, y *Rat) *Rat { func (z *Rat) Mul(x, y *Rat) *Rat {
z.a.Mul(&x.a, &y.a) z.a.Mul(&x.a, &y.a)
z.b = z.b.mul(x.b, y.b) z.b = mulDenom(z.b, x.b, y.b)
return z.norm() return z.norm()
} }
...@@ -175,28 +241,14 @@ func (z *Rat) Quo(x, y *Rat) *Rat { ...@@ -175,28 +241,14 @@ func (z *Rat) Quo(x, y *Rat) *Rat {
if len(y.a.abs) == 0 { if len(y.a.abs) == 0 {
panic("division by zero") panic("division by zero")
} }
a := mulNat(&x.a, y.b) a := scaleDenom(&x.a, y.b)
b := mulNat(&y.a, x.b) b := scaleDenom(&y.a, x.b)
z.a.abs = a.abs z.a.abs = a.abs
z.b = b.abs z.b = b.abs
z.a.neg = a.neg != b.neg z.a.neg = a.neg != b.neg
return z.norm() return z.norm()
} }
// Neg sets z to -x (by making a copy of x if necessary) and returns z.
func (z *Rat) Neg(x *Rat) *Rat {
z.a.Neg(&x.a)
z.b = z.b.set(x.b)
return z
}
// Set sets z to x (by making a copy of x if necessary) and returns z.
func (z *Rat) Set(x *Rat) *Rat {
z.a.Set(&x.a)
z.b = z.b.set(x.b)
return z
}
func ratTok(ch int) bool { func ratTok(ch int) bool {
return strings.IndexRune("+-/0123456789.eE", ch) >= 0 return strings.IndexRune("+-/0123456789.eE", ch) >= 0
} }
...@@ -219,23 +271,23 @@ func (z *Rat) Scan(s fmt.ScanState, ch int) os.Error { ...@@ -219,23 +271,23 @@ func (z *Rat) Scan(s fmt.ScanState, ch int) os.Error {
// SetString sets z to the value of s and returns z and a boolean indicating // SetString sets z to the value of s and returns z and a boolean indicating
// success. s can be given as a fraction "a/b" or as a floating-point number // success. s can be given as a fraction "a/b" or as a floating-point number
// optionally followed by an exponent. If the operation failed, the value of z // optionally followed by an exponent. If the operation failed, the value of
// is undefined. // z is undefined but the returned value is nil.
func (z *Rat) SetString(s string) (*Rat, bool) { func (z *Rat) SetString(s string) (*Rat, bool) {
if len(s) == 0 { if len(s) == 0 {
return z, false return nil, false
} }
// check for a quotient // check for a quotient
sep := strings.Index(s, "/") sep := strings.Index(s, "/")
if sep >= 0 { if sep >= 0 {
if _, ok := z.a.SetString(s[0:sep], 10); !ok { if _, ok := z.a.SetString(s[0:sep], 10); !ok {
return z, false return nil, false
} }
s = s[sep+1:] s = s[sep+1:]
var err os.Error var err os.Error
if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil { if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil {
return z, false return nil, false
} }
return z.norm(), true return z.norm(), true
} }
...@@ -248,10 +300,10 @@ func (z *Rat) SetString(s string) (*Rat, bool) { ...@@ -248,10 +300,10 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
if e >= 0 { if e >= 0 {
if e < sep { if e < sep {
// The E must come after the decimal point. // The E must come after the decimal point.
return z, false return nil, false
} }
if _, ok := exp.SetString(s[e+1:], 10); !ok { if _, ok := exp.SetString(s[e+1:], 10); !ok {
return z, false return nil, false
} }
s = s[0:e] s = s[0:e]
} }
...@@ -261,7 +313,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) { ...@@ -261,7 +313,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
} }
if _, ok := z.a.SetString(s, 10); !ok { if _, ok := z.a.SetString(s, 10); !ok {
return z, false return nil, false
} }
powTen := nat{}.expNN(natTen, exp.abs, nil) powTen := nat{}.expNN(natTen, exp.abs, nil)
if exp.neg { if exp.neg {
...@@ -269,7 +321,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) { ...@@ -269,7 +321,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
z.norm() z.norm()
} else { } else {
z.a.abs = z.a.abs.mul(z.a.abs, powTen) z.a.abs = z.a.abs.mul(z.a.abs, powTen)
z.b = z.b.setWord(1) z.b = z.b.make(0)
} }
return z, true return z, true
...@@ -277,7 +329,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) { ...@@ -277,7 +329,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) {
// String returns a string representation of z in the form "a/b" (even if b == 1). // String returns a string representation of z in the form "a/b" (even if b == 1).
func (z *Rat) String() string { func (z *Rat) String() string {
return z.a.String() + "/" + z.b.decimalString() s := "/1"
if len(z.b) != 0 {
s = "/" + z.b.decimalString()
}
return z.a.String() + s
} }
// RatString returns a string representation of z in the form "a/b" if b != 1, // RatString returns a string representation of z in the form "a/b" if b != 1,
...@@ -299,6 +355,7 @@ func (z *Rat) FloatString(prec int) string { ...@@ -299,6 +355,7 @@ func (z *Rat) FloatString(prec int) string {
} }
return s return s
} }
// z.b != 0
q, r := nat{}.div(nat{}, z.a.abs, z.b) q, r := nat{}.div(nat{}, z.a.abs, z.b)
......
...@@ -11,6 +11,46 @@ import ( ...@@ -11,6 +11,46 @@ import (
"testing" "testing"
) )
func TestZeroRat(t *testing.T) {
var x, y, z Rat
y.SetFrac64(0, 42)
if x.Cmp(&y) != 0 {
t.Errorf("x and y should be both equal and zero")
}
if s := x.String(); s != "0/1" {
t.Errorf("got x = %s, want 0/1", s)
}
if s := x.RatString(); s != "0" {
t.Errorf("got x = %s, want 0", s)
}
z.Add(&x, &y)
if s := z.RatString(); s != "0" {
t.Errorf("got x+y = %s, want 0", s)
}
z.Sub(&x, &y)
if s := z.RatString(); s != "0" {
t.Errorf("got x-y = %s, want 0", s)
}
z.Mul(&x, &y)
if s := z.RatString(); s != "0" {
t.Errorf("got x*y = %s, want 0", s)
}
// check for division by zero
defer func() {
if s := recover(); s == nil || s.(string) != "division by zero" {
panic(s)
}
}()
z.Quo(&x, &y)
}
var setStringTests = []struct { var setStringTests = []struct {
in, out string in, out string
ok bool ok bool
...@@ -50,8 +90,14 @@ func TestRatSetString(t *testing.T) { ...@@ -50,8 +90,14 @@ func TestRatSetString(t *testing.T) {
for i, test := range setStringTests { for i, test := range setStringTests {
x, ok := new(Rat).SetString(test.in) x, ok := new(Rat).SetString(test.in)
if ok != test.ok || ok && x.RatString() != test.out { if ok {
t.Errorf("#%d got %s want %s", i, x.RatString(), test.out) if !test.ok {
t.Errorf("#%d SetString(%q) expected failure", i, test.in)
} else if x.RatString() != test.out {
t.Errorf("#%d SetString(%q) got %s want %s", i, test.in, x.RatString(), test.out)
}
} else if x != nil {
t.Errorf("#%d SetString(%q) got %p want nil", i, test.in, x)
} }
} }
} }
...@@ -113,8 +159,10 @@ func TestFloatString(t *testing.T) { ...@@ -113,8 +159,10 @@ func TestFloatString(t *testing.T) {
func TestRatSign(t *testing.T) { func TestRatSign(t *testing.T) {
zero := NewRat(0, 1) zero := NewRat(0, 1)
for _, a := range setStringTests { for _, a := range setStringTests {
var x Rat x, ok := new(Rat).SetString(a.in)
x.SetString(a.in) if !ok {
continue
}
s := x.Sign() s := x.Sign()
e := x.Cmp(zero) e := x.Cmp(zero)
if s != e { if s != e {
...@@ -153,29 +201,65 @@ func TestRatCmp(t *testing.T) { ...@@ -153,29 +201,65 @@ func TestRatCmp(t *testing.T) {
func TestIsInt(t *testing.T) { func TestIsInt(t *testing.T) {
one := NewInt(1) one := NewInt(1)
for _, a := range setStringTests { for _, a := range setStringTests {
var x Rat x, ok := new(Rat).SetString(a.in)
x.SetString(a.in) if !ok {
continue
}
i := x.IsInt() i := x.IsInt()
e := x.Denom().Cmp(one) == 0 e := x.Denom().Cmp(one) == 0
if i != e { if i != e {
t.Errorf("got %v; want %v for z = %v", i, e, &x) t.Errorf("got IsInt(%v) == %v; want %v", x, i, e)
} }
} }
} }
func TestRatAbs(t *testing.T) { func TestRatAbs(t *testing.T) {
zero := NewRat(0, 1) zero := new(Rat)
for _, a := range setStringTests { for _, a := range setStringTests {
var z Rat x, ok := new(Rat).SetString(a.in)
z.SetString(a.in) if !ok {
var e Rat continue
e.Set(&z) }
e := new(Rat).Set(x)
if e.Cmp(zero) < 0 { if e.Cmp(zero) < 0 {
e.Sub(zero, &e) e.Sub(zero, e)
}
z := new(Rat).Abs(x)
if z.Cmp(e) != 0 {
t.Errorf("got Abs(%v) = %v; want %v", x, z, e)
}
}
}
func TestRatNeg(t *testing.T) {
zero := new(Rat)
for _, a := range setStringTests {
x, ok := new(Rat).SetString(a.in)
if !ok {
continue
}
e := new(Rat).Sub(zero, x)
z := new(Rat).Neg(x)
if z.Cmp(e) != 0 {
t.Errorf("got Neg(%v) = %v; want %v", x, z, e)
}
}
}
func TestRatInv(t *testing.T) {
zero := new(Rat)
for _, a := range setStringTests {
x, ok := new(Rat).SetString(a.in)
if !ok {
continue
}
if x.Cmp(zero) == 0 {
continue // avoid division by zero
} }
z.Abs(&z) e := new(Rat).SetFrac(x.Denom(), x.Num())
if z.Cmp(&e) != 0 { z := new(Rat).Inv(x)
t.Errorf("got z = %v; want %v", &z, &e) if z.Cmp(e) != 0 {
t.Errorf("got Inv(%v) = %v; want %v", x, z, e)
} }
} }
} }
...@@ -186,10 +270,10 @@ type ratBinArg struct { ...@@ -186,10 +270,10 @@ type ratBinArg struct {
} }
func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) { func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) {
x, _ := NewRat(0, 1).SetString(a.x) x, _ := new(Rat).SetString(a.x)
y, _ := NewRat(0, 1).SetString(a.y) y, _ := new(Rat).SetString(a.y)
z, _ := NewRat(0, 1).SetString(a.z) z, _ := new(Rat).SetString(a.z)
out := f(NewRat(0, 1), x, y) out := f(new(Rat), x, y)
if out.Cmp(z) != 0 { if out.Cmp(z) != 0 {
t.Errorf("%s #%d got %s want %s", name, i, out, z) t.Errorf("%s #%d got %s want %s", name, i, out, z)
......
...@@ -928,11 +928,11 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P ...@@ -928,11 +928,11 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P
return return
} }
asn1Issuer, err := asn1.Marshal(parent.Issuer.ToRDNSequence()) asn1Issuer, err := asn1.Marshal(parent.Subject.ToRDNSequence())
if err != nil { if err != nil {
return return
} }
asn1Subject, err := asn1.Marshal(parent.Subject.ToRDNSequence()) asn1Subject, err := asn1.Marshal(template.Subject.ToRDNSequence())
if err != nil { if err != nil {
return return
} }
......
...@@ -6,8 +6,8 @@ package x509 ...@@ -6,8 +6,8 @@ package x509
import ( import (
"asn1" "asn1"
"bytes"
"big" "big"
"bytes"
"crypto/dsa" "crypto/dsa"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
...@@ -243,10 +243,11 @@ func TestCreateSelfSignedCertificate(t *testing.T) { ...@@ -243,10 +243,11 @@ func TestCreateSelfSignedCertificate(t *testing.T) {
return return
} }
commonName := "test.example.com"
template := Certificate{ template := Certificate{
SerialNumber: big.NewInt(1), SerialNumber: big.NewInt(1),
Subject: pkix.Name{ Subject: pkix.Name{
CommonName: "test.example.com", CommonName: commonName,
Organization: []string{"Acme Co"}, Organization: []string{"Acme Co"},
}, },
NotBefore: time.SecondsToUTC(1000), NotBefore: time.SecondsToUTC(1000),
...@@ -283,6 +284,14 @@ func TestCreateSelfSignedCertificate(t *testing.T) { ...@@ -283,6 +284,14 @@ func TestCreateSelfSignedCertificate(t *testing.T) {
t.Errorf("Failed to parse name constraints: %#v", cert.PermittedDNSDomains) t.Errorf("Failed to parse name constraints: %#v", cert.PermittedDNSDomains)
} }
if cert.Subject.CommonName != commonName {
t.Errorf("Subject wasn't correctly copied from the template. Got %s, want %s", cert.Subject.CommonName, commonName)
}
if cert.Issuer.CommonName != commonName {
t.Errorf("Issuer wasn't correctly copied from the template. Got %s, want %s", cert.Issuer.CommonName, commonName)
}
err = cert.CheckSignatureFrom(cert) err = cert.CheckSignatureFrom(cert)
if err != nil { if err != nil {
t.Errorf("Signature verification failed: %s", err) t.Errorf("Signature verification failed: %s", err)
......
...@@ -6,8 +6,8 @@ package inotify ...@@ -6,8 +6,8 @@ package inotify
import ( import (
"os" "os"
"time"
"testing" "testing"
"time"
) )
func TestInotifyEvents(t *testing.T) { func TestInotifyEvents(t *testing.T) {
......
...@@ -68,7 +68,7 @@ type channel struct { ...@@ -68,7 +68,7 @@ type channel struct {
weClosed bool weClosed bool
dead bool dead bool
serverConn *ServerConnection serverConn *ServerConn
myId, theirId uint32 myId, theirId uint32
myWindow, theirWindow uint32 myWindow, theirWindow uint32
maxPacketSize uint32 maxPacketSize uint32
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"big"
"crypto"
"crypto/rand"
"fmt"
"io"
"net"
"os"
"sync"
)
// clientVersion is the fixed identification string that the client will use.
var clientVersion = []byte("SSH-2.0-Go\r\n")
// ClientConn represents the client side of an SSH connection.
type ClientConn struct {
*transport
config *ClientConfig
chanlist
}
// Client returns a new SSH client connection using c as the underlying transport.
func Client(c net.Conn, config *ClientConfig) (*ClientConn, os.Error) {
conn := &ClientConn{
transport: newTransport(c, config.rand()),
config: config,
}
if err := conn.handshake(); err != nil {
conn.Close()
return nil, err
}
if err := conn.authenticate(); err != nil {
conn.Close()
return nil, err
}
go conn.mainLoop()
return conn, nil
}
// handshake performs the client side key exchange. See RFC 4253 Section 7.
func (c *ClientConn) handshake() os.Error {
var magics handshakeMagics
if _, err := c.Write(clientVersion); err != nil {
return err
}
if err := c.Flush(); err != nil {
return err
}
magics.clientVersion = clientVersion[:len(clientVersion)-2]
// read remote server version
version, err := readVersion(c)
if err != nil {
return err
}
magics.serverVersion = version
clientKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: supportedCiphers,
CiphersServerClient: supportedCiphers,
MACsClientServer: supportedMACs,
MACsServerClient: supportedMACs,
CompressionClientServer: supportedCompressions,
CompressionServerClient: supportedCompressions,
}
kexInitPacket := marshal(msgKexInit, clientKexInit)
magics.clientKexInit = kexInitPacket
if err := c.writePacket(kexInitPacket); err != nil {
return err
}
packet, err := c.readPacket()
if err != nil {
return err
}
magics.serverKexInit = packet
var serverKexInit kexInitMsg
if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
return err
}
kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
if !ok {
return os.NewError("ssh: no common algorithms")
}
if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
// The server sent a Kex message for the wrong algorithm,
// which we have to ignore.
if _, err := c.readPacket(); err != nil {
return err
}
}
var H, K []byte
var hashFunc crypto.Hash
switch kexAlgo {
case kexAlgoDH14SHA1:
hashFunc = crypto.SHA1
dhGroup14Once.Do(initDHGroup14)
H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
default:
err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
}
if err != nil {
return err
}
if err = c.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
return err
}
if packet, err = c.readPacket(); err != nil {
return err
}
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
}
// authenticate authenticates with the remote server. See RFC 4252.
// Only "password" authentication is supported.
func (c *ClientConn) authenticate() os.Error {
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
packet, err := c.readPacket()
if err != nil {
return err
}
var serviceAccept serviceAcceptMsg
if err = unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
return err
}
// TODO(dfc) support proper authentication method negotation
method := "none"
if c.config.Password != "" {
method = "password"
}
if err := c.sendUserAuthReq(method); err != nil {
return err
}
if packet, err = c.readPacket(); err != nil {
return err
}
if packet[0] != msgUserAuthSuccess {
return UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
}
return nil
}
func (c *ClientConn) sendUserAuthReq(method string) os.Error {
length := stringLength([]byte(c.config.Password)) + 1
payload := make([]byte, length)
// always false for password auth, see RFC 4252 Section 8.
payload[0] = 0
marshalString(payload[1:], []byte(c.config.Password))
return c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
User: c.config.User,
Service: serviceSSH,
Method: method,
Payload: payload,
}))
}
// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
// returned values are given the same names as in RFC 4253, section 8.
func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, os.Error) {
x, err := rand.Int(c.config.rand(), group.p)
if err != nil {
return nil, nil, err
}
X := new(big.Int).Exp(group.g, x, group.p)
kexDHInit := kexDHInitMsg{
X: X,
}
if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
return nil, nil, err
}
packet, err := c.readPacket()
if err != nil {
return nil, nil, err
}
var kexDHReply = new(kexDHReplyMsg)
if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
return nil, nil, err
}
if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
return nil, nil, os.NewError("server DH parameter out of bounds")
}
kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
h := hashFunc.New()
writeString(h, magics.clientVersion)
writeString(h, magics.serverVersion)
writeString(h, magics.clientKexInit)
writeString(h, magics.serverKexInit)
writeString(h, kexDHReply.HostKey)
writeInt(h, X)
writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
h.Write(K)
H := h.Sum()
return H, K, nil
}
// openChan opens a new client channel. The most common session type is "session".
// The full set of valid session types are listed in RFC 4250 4.9.1.
func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) {
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
ChanType: typ,
PeersId: ch.id,
PeersWindow: 1 << 14,
MaxPacketSize: 1 << 15, // RFC 4253 6.1
})); err != nil {
c.chanlist.remove(ch.id)
return nil, err
}
// wait for response
switch msg := (<-ch.msg).(type) {
case *channelOpenConfirmMsg:
ch.peersId = msg.MyId
case *channelOpenFailureMsg:
c.chanlist.remove(ch.id)
return nil, os.NewError(msg.Message)
default:
c.chanlist.remove(ch.id)
return nil, os.NewError("Unexpected packet")
}
return ch, nil
}
// mainloop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
for {
packet, err := c.readPacket()
if err != nil {
// TODO(dfc) signal the underlying close to all channels
c.Close()
return
}
// TODO(dfc) A note on blocking channel use.
// The msg, win, data and dataExt channels of a clientChan can
// cause this loop to block indefinately if the consumer does
// not service them.
switch msg := decode(packet).(type) {
case *channelOpenMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenConfirmMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelCloseMsg:
ch := c.getChan(msg.PeersId)
close(ch.win)
close(ch.data)
close(ch.dataExt)
c.chanlist.remove(msg.PeersId)
case *channelEOFMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestSuccessMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestMsg:
c.getChan(msg.PeersId).msg <- msg
case *windowAdjustMsg:
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
case *channelData:
c.getChan(msg.PeersId).data <- msg.Payload
case *channelExtendedData:
// RFC 4254 5.2 defines data_type_code 1 to be data destined
// for stderr on interactive sessions. Other data types are
// silently discarded.
if msg.Datatype == 1 {
c.getChan(msg.PeersId).dataExt <- msg.Payload
}
default:
fmt.Printf("mainLoop: unhandled %#v\n", msg)
}
}
}
// Dial connects to the given network address using net.Dial and
// then initiates a SSH handshake, returning the resulting client connection.
func Dial(network, addr string, config *ClientConfig) (*ClientConn, os.Error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
return Client(conn, config)
}
// A ClientConfig structure is used to configure a ClientConn. After one has
// been passed to an SSH function it must not be modified.
type ClientConfig struct {
// Rand provides the source of entropy for key exchange. If Rand is
// nil, the cryptographic random reader in package crypto/rand will
// be used.
Rand io.Reader
// The username to authenticate.
User string
// Used for "password" method authentication.
Password string
}
func (c *ClientConfig) rand() io.Reader {
if c.Rand == nil {
return rand.Reader
}
return c.Rand
}
// A clientChan represents a single RFC 4254 channel that is multiplexed
// over a single SSH connection.
type clientChan struct {
packetWriter
id, peersId uint32
data chan []byte // receives the payload of channelData messages
dataExt chan []byte // receives the payload of channelExtendedData messages
win chan int // receives window adjustments
msg chan interface{} // incoming messages
}
func newClientChan(t *transport, id uint32) *clientChan {
return &clientChan{
packetWriter: t,
id: id,
data: make(chan []byte, 16),
dataExt: make(chan []byte, 16),
win: make(chan int, 16),
msg: make(chan interface{}, 16),
}
}
// Close closes the channel. This does not close the underlying connection.
func (c *clientChan) Close() os.Error {
return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
PeersId: c.id,
}))
}
func (c *clientChan) sendChanReq(req channelRequestMsg) os.Error {
if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
return err
}
msg := <-c.msg
if _, ok := msg.(*channelRequestSuccessMsg); ok {
return nil
}
return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg)
}
// Thread safe channel list.
type chanlist struct {
// protects concurrent access to chans
sync.Mutex
// chans are indexed by the local id of the channel, clientChan.id.
// The PeersId value of messages received by ClientConn.mainloop is
// used to locate the right local clientChan in this slice.
chans []*clientChan
}
// Allocate a new ClientChan with the next avail local id.
func (c *chanlist) newChan(t *transport) *clientChan {
c.Lock()
defer c.Unlock()
for i := range c.chans {
if c.chans[i] == nil {
ch := newClientChan(t, uint32(i))
c.chans[i] = ch
return ch
}
}
i := len(c.chans)
ch := newClientChan(t, uint32(i))
c.chans = append(c.chans, ch)
return ch
}
func (c *chanlist) getChan(id uint32) *clientChan {
c.Lock()
defer c.Unlock()
return c.chans[int(id)]
}
func (c *chanlist) remove(id uint32) {
c.Lock()
defer c.Unlock()
c.chans[int(id)] = nil
}
// A chanWriter represents the stdin of a remote process.
type chanWriter struct {
win chan int // receives window adjustments
id uint32 // this channel's id
rwin int // current rwin size
packetWriter // for sending channelDataMsg
}
// Write writes data to the remote process's standard input.
func (w *chanWriter) Write(data []byte) (n int, err os.Error) {
for {
if w.rwin == 0 {
win, ok := <-w.win
if !ok {
return 0, os.EOF
}
w.rwin += win
continue
}
n = len(data)
packet := make([]byte, 0, 9+n)
packet = append(packet, msgChannelData,
byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id),
byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n))
err = w.writePacket(append(packet, data...))
w.rwin -= n
return
}
panic("unreachable")
}
func (w *chanWriter) Close() os.Error {
return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id}))
}
// A chanReader represents stdout or stderr of a remote process.
type chanReader struct {
// TODO(dfc) a fixed size channel may not be the right data structure.
// If writes to this channel block, they will block mainLoop, making
// it unable to receive new messages from the remote side.
data chan []byte // receives data from remote
id uint32
packetWriter // for sending windowAdjustMsg
buf []byte
}
// Read reads data from the remote process's stdout or stderr.
func (r *chanReader) Read(data []byte) (int, os.Error) {
var ok bool
for {
if len(r.buf) > 0 {
n := copy(data, r.buf)
r.buf = r.buf[n:]
msg := windowAdjustMsg{
PeersId: r.id,
AdditionalBytes: uint32(n),
}
return n, r.writePacket(marshal(msgChannelWindowAdjust, msg))
}
r.buf, ok = <-r.data
if !ok {
return 0, os.EOF
}
}
panic("unreachable")
}
func (r *chanReader) Close() os.Error {
return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id}))
}
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
/* /*
Package ssh implements an SSH server. Package ssh implements an SSH client and server.
SSH is a transport security protocol, an authentication protocol and a SSH is a transport security protocol, an authentication protocol and a
family of application protocols. The most typical application level family of application protocols. The most typical application level
...@@ -11,26 +11,29 @@ protocol is a remote shell and this is specifically implemented. However, ...@@ -11,26 +11,29 @@ protocol is a remote shell and this is specifically implemented. However,
the multiplexed nature of SSH is exposed to users that wish to support the multiplexed nature of SSH is exposed to users that wish to support
others. others.
An SSH server is represented by a Server, which manages a number of An SSH server is represented by a ServerConfig, which holds certificate
ServerConnections and handles authentication. details and handles authentication of ServerConns.
var s Server config := new(ServerConfig)
s.PubKeyCallback = pubKeyAuth config.PubKeyCallback = pubKeyAuth
s.PasswordCallback = passwordAuth config.PasswordCallback = passwordAuth
pemBytes, err := ioutil.ReadFile("id_rsa") pemBytes, err := ioutil.ReadFile("id_rsa")
if err != nil { if err != nil {
panic("Failed to load private key") panic("Failed to load private key")
} }
err = s.SetRSAPrivateKey(pemBytes) err = config.SetRSAPrivateKey(pemBytes)
if err != nil { if err != nil {
panic("Failed to parse private key") panic("Failed to parse private key")
} }
Once a Server has been set up, connections can be attached. Once a ServerConfig has been configured, connections can be accepted.
var sConn ServerConnection listener := Listen("tcp", "0.0.0.0:2022", config)
sConn.Server = &s sConn, err := listener.Accept()
if err != nil {
panic("failed to accept incoming connection")
}
err = sConn.Handshake(conn) err = sConn.Handshake(conn)
if err != nil { if err != nil {
panic("failed to handshake") panic("failed to handshake")
...@@ -38,7 +41,6 @@ Once a Server has been set up, connections can be attached. ...@@ -38,7 +41,6 @@ Once a Server has been set up, connections can be attached.
An SSH connection multiplexes several channels, which must be accepted themselves: An SSH connection multiplexes several channels, which must be accepted themselves:
for { for {
channel, err := sConn.Accept() channel, err := sConn.Accept()
if err != nil { if err != nil {
...@@ -75,5 +77,29 @@ present a simple terminal interface. ...@@ -75,5 +77,29 @@ present a simple terminal interface.
} }
return return
}() }()
An SSH client is represented with a ClientConn. Currently only the "password"
authentication method is supported.
config := &ClientConfig{
User: "username",
Password: "123456",
}
client, err := Dial("yourserver.com:22", config)
Each ClientConn can support multiple interactive sessions, represented by a Session.
session, err := client.NewSession()
Once a Session is created, you can execute a single command on the remote side
using the Exec method.
if err := session.Exec("/usr/bin/whoami"); err != nil {
panic("Failed to exec: " + err.String())
}
reader := bufio.NewReader(session.Stdin)
line, _, _ := reader.ReadLine()
fmt.Println(line)
session.Close()
*/ */
package ssh package ssh
...@@ -154,7 +154,7 @@ type channelData struct { ...@@ -154,7 +154,7 @@ type channelData struct {
type channelExtendedData struct { type channelExtendedData struct {
PeersId uint32 PeersId uint32
Datatype uint32 Datatype uint32
Data string Payload []byte `ssh:"rest"`
} }
type channelRequestMsg struct { type channelRequestMsg struct {
......
...@@ -10,19 +10,23 @@ import ( ...@@ -10,19 +10,23 @@ import (
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
_ "crypto/sha1"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"io"
"net" "net"
"os" "os"
"sync" "sync"
) )
// Server represents an SSH server. A Server may have several ServerConnections. type ServerConfig struct {
type Server struct {
rsa *rsa.PrivateKey rsa *rsa.PrivateKey
rsaSerialized []byte rsaSerialized []byte
// Rand provides the source of entropy for key exchange. If Rand is
// nil, the cryptographic random reader in package crypto/rand will
// be used.
Rand io.Reader
// NoClientAuth is true if clients are allowed to connect without // NoClientAuth is true if clients are allowed to connect without
// authenticating. // authenticating.
NoClientAuth bool NoClientAuth bool
...@@ -38,11 +42,18 @@ type Server struct { ...@@ -38,11 +42,18 @@ type Server struct {
PubKeyCallback func(user, algo string, pubkey []byte) bool PubKeyCallback func(user, algo string, pubkey []byte) bool
} }
func (c *ServerConfig) rand() io.Reader {
if c.Rand == nil {
return rand.Reader
}
return c.Rand
}
// SetRSAPrivateKey sets the private key for a Server. A Server must have a // SetRSAPrivateKey sets the private key for a Server. A Server must have a
// private key configured in order to accept connections. The private key must // private key configured in order to accept connections. The private key must
// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa" // be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
// typically contains such a key. // typically contains such a key.
func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error { func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) os.Error {
block, _ := pem.Decode(pemBytes) block, _ := pem.Decode(pemBytes)
if block == nil { if block == nil {
return os.NewError("ssh: no key found") return os.NewError("ssh: no key found")
...@@ -109,7 +120,7 @@ func parseRSASig(in []byte) (sig []byte, ok bool) { ...@@ -109,7 +120,7 @@ func parseRSASig(in []byte) (sig []byte, ok bool) {
} }
// cachedPubKey contains the results of querying whether a public key is // cachedPubKey contains the results of querying whether a public key is
// acceptable for a user. The cache only applies to a single ServerConnection. // acceptable for a user. The cache only applies to a single ServerConn.
type cachedPubKey struct { type cachedPubKey struct {
user, algo string user, algo string
pubKey []byte pubKey []byte
...@@ -118,11 +129,10 @@ type cachedPubKey struct { ...@@ -118,11 +129,10 @@ type cachedPubKey struct {
const maxCachedPubKeys = 16 const maxCachedPubKeys = 16
// ServerConnection represents an incomming connection to a Server. // A ServerConn represents an incomming connection.
type ServerConnection struct { type ServerConn struct {
Server *Server
*transport *transport
config *ServerConfig
channels map[uint32]*channel channels map[uint32]*channel
nextChanId uint32 nextChanId uint32
...@@ -139,9 +149,20 @@ type ServerConnection struct { ...@@ -139,9 +149,20 @@ type ServerConnection struct {
cachedPubKeys []cachedPubKey cachedPubKeys []cachedPubKey
} }
// Server returns a new SSH server connection
// using c as the underlying transport.
func Server(c net.Conn, config *ServerConfig) *ServerConn {
conn := &ServerConn{
transport: newTransport(c, config.rand()),
channels: make(map[uint32]*channel),
config: config,
}
return conn
}
// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
// returned values are given the same names as in RFC 4253, section 8. // returned values are given the same names as in RFC 4253, section 8.
func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
packet, err := s.readPacket() packet, err := s.readPacket()
if err != nil { if err != nil {
return return
...@@ -155,7 +176,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h ...@@ -155,7 +176,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
return nil, nil, os.NewError("client DH parameter out of bounds") return nil, nil, os.NewError("client DH parameter out of bounds")
} }
y, err := rand.Int(rand.Reader, group.p) y, err := rand.Int(s.config.rand(), group.p)
if err != nil { if err != nil {
return return
} }
...@@ -166,7 +187,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h ...@@ -166,7 +187,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
var serializedHostKey []byte var serializedHostKey []byte
switch hostKeyAlgo { switch hostKeyAlgo {
case hostAlgoRSA: case hostAlgoRSA:
serializedHostKey = s.Server.rsaSerialized serializedHostKey = s.config.rsaSerialized
default: default:
return nil, nil, os.NewError("internal error") return nil, nil, os.NewError("internal error")
} }
...@@ -192,7 +213,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h ...@@ -192,7 +213,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h
var sig []byte var sig []byte
switch hostKeyAlgo { switch hostKeyAlgo {
case hostAlgoRSA: case hostAlgoRSA:
sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh) sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh)
if err != nil { if err != nil {
return return
} }
...@@ -257,19 +278,20 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK ...@@ -257,19 +278,20 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
return ret return ret
} }
// Handshake performs an SSH transport and client authentication on the given ServerConnection. // Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConnection) Handshake(conn net.Conn) os.Error { func (s *ServerConn) Handshake() os.Error {
var magics handshakeMagics var magics handshakeMagics
s.transport = newTransport(conn, rand.Reader) if _, err := s.Write(serverVersion); err != nil {
return err
if _, err := conn.Write(serverVersion); err != nil { }
if err := s.Flush(); err != nil {
return err return err
} }
magics.serverVersion = serverVersion[:len(serverVersion)-2] magics.serverVersion = serverVersion[:len(serverVersion)-2]
version, ok := readVersion(s.transport) version, err := readVersion(s)
if !ok { if err != nil {
return os.NewError("failed to read version string from client") return err
} }
magics.clientVersion = version magics.clientVersion = version
...@@ -310,8 +332,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { ...@@ -310,8 +332,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm, // The client sent a Kex message for the wrong algorithm,
// which we have to ignore. // which we have to ignore.
_, err := s.readPacket() if _, err := s.readPacket(); err != nil {
if err != nil {
return err return err
} }
} }
...@@ -324,32 +345,27 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { ...@@ -324,32 +345,27 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
dhGroup14Once.Do(initDHGroup14) dhGroup14Once.Do(initDHGroup14)
H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
default: default:
err = os.NewError("ssh: internal error") err = os.NewError("ssh: unexpected key exchange algorithm " + kexAlgo)
} }
if err != nil { if err != nil {
return err return err
} }
packet = []byte{msgNewKeys} if err = s.writePacket([]byte{msgNewKeys}); err != nil {
if err = s.writePacket(packet); err != nil {
return err return err
} }
if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
return err return err
} }
if packet, err = s.readPacket(); err != nil { if packet, err = s.readPacket(); err != nil {
return err return err
} }
if packet[0] != msgNewKeys { if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]} return UnexpectedMessageError{msgNewKeys, packet[0]}
} }
s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
if packet, err = s.readPacket(); err != nil {
packet, err = s.readPacket()
if err != nil {
return err return err
} }
...@@ -360,20 +376,16 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { ...@@ -360,20 +376,16 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
if serviceRequest.Service != serviceUserAuth { if serviceRequest.Service != serviceUserAuth {
return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating") return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
} }
serviceAccept := serviceAcceptMsg{ serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth, Service: serviceUserAuth,
} }
packet = marshal(msgServiceAccept, serviceAccept) if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
if err = s.writePacket(packet); err != nil {
return err return err
} }
if err = s.authenticate(H); err != nil { if err = s.authenticate(H); err != nil {
return err return err
} }
s.channels = make(map[uint32]*channel)
return nil return nil
} }
...@@ -382,8 +394,8 @@ func isAcceptableAlgo(algo string) bool { ...@@ -382,8 +394,8 @@ func isAcceptableAlgo(algo string) bool {
} }
// testPubKey returns true if the given public key is acceptable for the user. // testPubKey returns true if the given public key is acceptable for the user.
func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool {
if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) { if s.config.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
return false return false
} }
...@@ -393,7 +405,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { ...@@ -393,7 +405,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
} }
} }
result := s.Server.PubKeyCallback(user, algo, pubKey) result := s.config.PubKeyCallback(user, algo, pubKey)
if len(s.cachedPubKeys) < maxCachedPubKeys { if len(s.cachedPubKeys) < maxCachedPubKeys {
c := cachedPubKey{ c := cachedPubKey{
user: user, user: user,
...@@ -408,7 +420,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { ...@@ -408,7 +420,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
return result return result
} }
func (s *ServerConnection) authenticate(H []byte) os.Error { func (s *ServerConn) authenticate(H []byte) os.Error {
var userAuthReq userAuthRequestMsg var userAuthReq userAuthRequestMsg
var err os.Error var err os.Error
var packet []byte var packet []byte
...@@ -428,11 +440,11 @@ userAuthLoop: ...@@ -428,11 +440,11 @@ userAuthLoop:
switch userAuthReq.Method { switch userAuthReq.Method {
case "none": case "none":
if s.Server.NoClientAuth { if s.config.NoClientAuth {
break userAuthLoop break userAuthLoop
} }
case "password": case "password":
if s.Server.PasswordCallback == nil { if s.config.PasswordCallback == nil {
break break
} }
payload := userAuthReq.Payload payload := userAuthReq.Payload
...@@ -445,11 +457,11 @@ userAuthLoop: ...@@ -445,11 +457,11 @@ userAuthLoop:
return ParseError{msgUserAuthRequest} return ParseError{msgUserAuthRequest}
} }
if s.Server.PasswordCallback(userAuthReq.User, string(password)) { if s.config.PasswordCallback(userAuthReq.User, string(password)) {
break userAuthLoop break userAuthLoop
} }
case "publickey": case "publickey":
if s.Server.PubKeyCallback == nil { if s.config.PubKeyCallback == nil {
break break
} }
payload := userAuthReq.Payload payload := userAuthReq.Payload
...@@ -520,10 +532,10 @@ userAuthLoop: ...@@ -520,10 +532,10 @@ userAuthLoop:
} }
var failureMsg userAuthFailureMsg var failureMsg userAuthFailureMsg
if s.Server.PasswordCallback != nil { if s.config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password") failureMsg.Methods = append(failureMsg.Methods, "password")
} }
if s.Server.PubKeyCallback != nil { if s.config.PubKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey") failureMsg.Methods = append(failureMsg.Methods, "publickey")
} }
...@@ -546,9 +558,9 @@ userAuthLoop: ...@@ -546,9 +558,9 @@ userAuthLoop:
const defaultWindowSize = 32768 const defaultWindowSize = 32768
// Accept reads and processes messages on a ServerConnection. It must be called // Accept reads and processes messages on a ServerConn. It must be called
// in order to demultiplex messages to any resulting Channels. // in order to demultiplex messages to any resulting Channels.
func (s *ServerConnection) Accept() (Channel, os.Error) { func (s *ServerConn) Accept() (Channel, os.Error) {
if s.err != nil { if s.err != nil {
return nil, s.err return nil, s.err
} }
...@@ -643,3 +655,44 @@ func (s *ServerConnection) Accept() (Channel, os.Error) { ...@@ -643,3 +655,44 @@ func (s *ServerConnection) Accept() (Channel, os.Error) {
panic("unreachable") panic("unreachable")
} }
// A Listener implements a network listener (net.Listener) for SSH connections.
type Listener struct {
listener net.Listener
config *ServerConfig
}
// Accept waits for and returns the next incoming SSH connection.
// The receiver should call Handshake() in another goroutine
// to avoid blocking the accepter.
func (l *Listener) Accept() (*ServerConn, os.Error) {
c, err := l.listener.Accept()
if err != nil {
return nil, err
}
conn := Server(c, l.config)
return conn, nil
}
// Addr returns the listener's network address.
func (l *Listener) Addr() net.Addr {
return l.listener.Addr()
}
// Close closes the listener.
func (l *Listener) Close() os.Error {
return l.listener.Close()
}
// Listen creates an SSH listener accepting connections on
// the given network address using net.Listen.
func Listen(network, addr string, config *ServerConfig) (*Listener, os.Error) {
l, err := net.Listen(network, addr)
if err != nil {
return nil, err
}
return &Listener{
l,
config,
}, nil
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
// Session implements an interactive session described in
// "RFC 4254, section 6".
import (
"encoding/binary"
"io"
"os"
)
// A Session represents a connection to a remote command or shell.
type Session struct {
// Writes to Stdin are made available to the remote command's standard input.
// Closing Stdin causes the command to observe an EOF on its standard input.
Stdin io.WriteCloser
// Reads from Stdout and Stderr consume from the remote command's standard
// output and error streams, respectively.
// There is a fixed amount of buffering that is shared for the two streams.
// Failing to read from either may eventually cause the command to block.
// Closing Stdout unblocks such writes and causes them to return errors.
Stdout io.ReadCloser
Stderr io.Reader
*clientChan // the channel backing this session
started bool // started is set to true once a Shell or Exec is invoked.
}
// Setenv sets an environment variable that will be applied to any
// command executed by Shell or Exec.
func (s *Session) Setenv(name, value string) os.Error {
n, v := []byte(name), []byte(value)
nlen, vlen := stringLength(n), stringLength(v)
payload := make([]byte, nlen+vlen)
marshalString(payload[:nlen], n)
marshalString(payload[nlen:], v)
return s.sendChanReq(channelRequestMsg{
PeersId: s.id,
Request: "env",
WantReply: true,
RequestSpecificData: payload,
})
}
// An empty mode list (a string of 1 character, opcode 0), see RFC 4254 Section 8.
var emptyModeList = []byte{0, 0, 0, 1, 0}
// RequestPty requests the association of a pty with the session on the remote host.
func (s *Session) RequestPty(term string, h, w int) os.Error {
buf := make([]byte, 4+len(term)+16+len(emptyModeList))
b := marshalString(buf, []byte(term))
binary.BigEndian.PutUint32(b, uint32(h))
binary.BigEndian.PutUint32(b[4:], uint32(w))
binary.BigEndian.PutUint32(b[8:], uint32(h*8))
binary.BigEndian.PutUint32(b[12:], uint32(w*8))
copy(b[16:], emptyModeList)
return s.sendChanReq(channelRequestMsg{
PeersId: s.id,
Request: "pty-req",
WantReply: true,
RequestSpecificData: buf,
})
}
// Exec runs cmd on the remote host. Typically, the remote
// server passes cmd to the shell for interpretation.
// A Session only accepts one call to Exec or Shell.
func (s *Session) Exec(cmd string) os.Error {
if s.started {
return os.NewError("session already started")
}
cmdLen := stringLength([]byte(cmd))
payload := make([]byte, cmdLen)
marshalString(payload, []byte(cmd))
s.started = true
return s.sendChanReq(channelRequestMsg{
PeersId: s.id,
Request: "exec",
WantReply: true,
RequestSpecificData: payload,
})
}
// Shell starts a login shell on the remote host. A Session only
// accepts one call to Exec or Shell.
func (s *Session) Shell() os.Error {
if s.started {
return os.NewError("session already started")
}
s.started = true
return s.sendChanReq(channelRequestMsg{
PeersId: s.id,
Request: "shell",
WantReply: true,
})
}
// NewSession returns a new interactive session on the remote host.
func (c *ClientConn) NewSession() (*Session, os.Error) {
ch, err := c.openChan("session")
if err != nil {
return nil, err
}
return &Session{
Stdin: &chanWriter{
packetWriter: ch,
id: ch.id,
win: ch.win,
},
Stdout: &chanReader{
packetWriter: ch,
id: ch.id,
data: ch.data,
},
Stderr: &chanReader{
packetWriter: ch,
id: ch.id,
data: ch.dataExt,
},
clientChan: ch,
}, nil
}
...@@ -332,16 +332,15 @@ func (t truncatingMAC) Size() int { ...@@ -332,16 +332,15 @@ func (t truncatingMAC) Size() int {
const maxVersionStringBytes = 1024 const maxVersionStringBytes = 1024
// Read version string as specified by RFC 4253, section 4.2. // Read version string as specified by RFC 4253, section 4.2.
func readVersion(r io.Reader) (versionString []byte, ok bool) { func readVersion(r io.Reader) ([]byte, os.Error) {
versionString = make([]byte, 0, 64) versionString := make([]byte, 0, 64)
seenCR := false var ok, seenCR bool
var buf [1]byte var buf [1]byte
forEachByte: forEachByte:
for len(versionString) < maxVersionStringBytes { for len(versionString) < maxVersionStringBytes {
_, err := io.ReadFull(r, buf[:]) _, err := io.ReadFull(r, buf[:])
if err != nil { if err != nil {
return return nil, err
} }
b := buf[0] b := buf[0]
...@@ -360,10 +359,10 @@ forEachByte: ...@@ -360,10 +359,10 @@ forEachByte:
versionString = append(versionString, b) versionString = append(versionString, b)
} }
if ok { if !ok {
// We need to remove the CR from versionString return nil, os.NewError("failed to read version string")
versionString = versionString[:len(versionString)-1]
} }
return // We need to remove the CR from versionString
return versionString[:len(versionString)-1], nil
} }
...@@ -12,9 +12,9 @@ import ( ...@@ -12,9 +12,9 @@ import (
func TestReadVersion(t *testing.T) { func TestReadVersion(t *testing.T) {
buf := []byte(serverVersion) buf := []byte(serverVersion)
result, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))) result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
if !ok { if err != nil {
t.Error("readVersion didn't read version correctly") t.Errorf("readVersion didn't read version correctly: %s", err)
} }
if !bytes.Equal(buf[:len(buf)-2], result) { if !bytes.Equal(buf[:len(buf)-2], result) {
t.Error("version read did not match expected") t.Error("version read did not match expected")
...@@ -23,7 +23,7 @@ func TestReadVersion(t *testing.T) { ...@@ -23,7 +23,7 @@ func TestReadVersion(t *testing.T) {
func TestReadVersionTooLong(t *testing.T) { func TestReadVersionTooLong(t *testing.T) {
buf := make([]byte, maxVersionStringBytes+1) buf := make([]byte, maxVersionStringBytes+1)
if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok { if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
t.Errorf("readVersion consumed %d bytes without error", len(buf)) t.Errorf("readVersion consumed %d bytes without error", len(buf))
} }
} }
...@@ -31,7 +31,7 @@ func TestReadVersionTooLong(t *testing.T) { ...@@ -31,7 +31,7 @@ func TestReadVersionTooLong(t *testing.T) {
func TestReadVersionWithoutCRLF(t *testing.T) { func TestReadVersionWithoutCRLF(t *testing.T) {
buf := []byte(serverVersion) buf := []byte(serverVersion)
buf = buf[:len(buf)-1] buf = buf[:len(buf)-1]
if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok { if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
t.Error("readVersion did not notice \\n was missing") t.Error("readVersion did not notice \\n was missing")
} }
} }
...@@ -289,9 +289,10 @@ func (p *gcParser) parseExportedName() (*ast.Object, string) { ...@@ -289,9 +289,10 @@ func (p *gcParser) parseExportedName() (*ast.Object, string) {
// BasicType = identifier . // BasicType = identifier .
// //
func (p *gcParser) parseBasicType() Type { func (p *gcParser) parseBasicType() Type {
obj := Universe.Lookup(p.expect(scanner.Ident)) id := p.expect(scanner.Ident)
obj := Universe.Lookup(id)
if obj == nil || obj.Kind != ast.Typ { if obj == nil || obj.Kind != ast.Typ {
p.errorf("not a basic type: %s", obj.Name) p.errorf("not a basic type: %s", id)
} }
return obj.Type.(Type) return obj.Type.(Type)
} }
......
...@@ -6,8 +6,8 @@ package winfsnotify ...@@ -6,8 +6,8 @@ package winfsnotify
import ( import (
"os" "os"
"time"
"testing" "testing"
"time"
) )
func expect(t *testing.T, eventstream <-chan *Event, name string, mask uint32) { func expect(t *testing.T, eventstream <-chan *Event, name string, mask uint32) {
...@@ -70,15 +70,11 @@ func TestNotifyEvents(t *testing.T) { ...@@ -70,15 +70,11 @@ func TestNotifyEvents(t *testing.T) {
if _, err = file.WriteString("hello, world"); err != nil { if _, err = file.WriteString("hello, world"); err != nil {
t.Fatalf("failed to write to test file: %s", err) t.Fatalf("failed to write to test file: %s", err)
} }
if err = file.Sync(); err != nil {
t.Fatalf("failed to sync test file: %s", err)
}
expect(t, watcher.Event, testFile, FS_MODIFY)
expect(t, watcher.Event, testFile, FS_MODIFY)
if err = file.Close(); err != nil { if err = file.Close(); err != nil {
t.Fatalf("failed to close test file: %s", err) t.Fatalf("failed to close test file: %s", err)
} }
expect(t, watcher.Event, testFile, FS_MODIFY)
expect(t, watcher.Event, testFile, FS_MODIFY)
if err = os.Rename(testFile, testFile2); err != nil { if err = os.Rename(testFile, testFile2); err != nil {
t.Fatalf("failed to rename test file: %s", err) t.Fatalf("failed to rename test file: %s", err)
......
...@@ -88,6 +88,10 @@ type S struct { ...@@ -88,6 +88,10 @@ type S struct {
G G // a struct field that GoStrings G G // a struct field that GoStrings
} }
type SI struct {
I interface{}
}
// A type with a String method with pointer receiver for testing %p // A type with a String method with pointer receiver for testing %p
type P int type P int
...@@ -352,6 +356,7 @@ var fmttests = []struct { ...@@ -352,6 +356,7 @@ var fmttests = []struct {
{"%#v", map[string]int{"a": 1}, `map[string] int{"a":1}`}, {"%#v", map[string]int{"a": 1}, `map[string] int{"a":1}`},
{"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`}, {"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`},
{"%#v", []string{"a", "b"}, `[]string{"a", "b"}`}, {"%#v", []string{"a", "b"}, `[]string{"a", "b"}`},
{"%#v", SI{}, `fmt_test.SI{I:interface { }(nil)}`},
// slices with other formats // slices with other formats
{"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`}, {"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`},
......
...@@ -74,6 +74,8 @@ type pp struct { ...@@ -74,6 +74,8 @@ type pp struct {
n int n int
panicking bool panicking bool
buf bytes.Buffer buf bytes.Buffer
// field holds the current item, as an interface{}.
field interface{}
// value holds the current item, as a reflect.Value, and will be // value holds the current item, as a reflect.Value, and will be
// the zero Value if the item has not been reflected. // the zero Value if the item has not been reflected.
value reflect.Value value reflect.Value
...@@ -132,6 +134,7 @@ func (p *pp) free() { ...@@ -132,6 +134,7 @@ func (p *pp) free() {
return return
} }
p.buf.Reset() p.buf.Reset()
p.field = nil
p.value = reflect.Value{} p.value = reflect.Value{}
ppFree.put(p) ppFree.put(p)
} }
...@@ -294,16 +297,16 @@ func (p *pp) unknownType(v interface{}) { ...@@ -294,16 +297,16 @@ func (p *pp) unknownType(v interface{}) {
p.buf.WriteByte('?') p.buf.WriteByte('?')
} }
func (p *pp) badVerb(verb int, val interface{}) { func (p *pp) badVerb(verb int) {
p.add('%') p.add('%')
p.add('!') p.add('!')
p.add(verb) p.add(verb)
p.add('(') p.add('(')
switch { switch {
case val != nil: case p.field != nil:
p.buf.WriteString(reflect.TypeOf(val).String()) p.buf.WriteString(reflect.TypeOf(p.field).String())
p.add('=') p.add('=')
p.printField(val, 'v', false, false, 0) p.printField(p.field, 'v', false, false, 0)
case p.value.IsValid(): case p.value.IsValid():
p.buf.WriteString(p.value.Type().String()) p.buf.WriteString(p.value.Type().String())
p.add('=') p.add('=')
...@@ -314,12 +317,12 @@ func (p *pp) badVerb(verb int, val interface{}) { ...@@ -314,12 +317,12 @@ func (p *pp) badVerb(verb int, val interface{}) {
p.add(')') p.add(')')
} }
func (p *pp) fmtBool(v bool, verb int, value interface{}) { func (p *pp) fmtBool(v bool, verb int) {
switch verb { switch verb {
case 't', 'v': case 't', 'v':
p.fmt.fmt_boolean(v) p.fmt.fmt_boolean(v)
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
...@@ -333,7 +336,7 @@ func (p *pp) fmtC(c int64) { ...@@ -333,7 +336,7 @@ func (p *pp) fmtC(c int64) {
p.fmt.pad(p.runeBuf[0:w]) p.fmt.pad(p.runeBuf[0:w])
} }
func (p *pp) fmtInt64(v int64, verb int, value interface{}) { func (p *pp) fmtInt64(v int64, verb int) {
switch verb { switch verb {
case 'b': case 'b':
p.fmt.integer(v, 2, signed, ldigits) p.fmt.integer(v, 2, signed, ldigits)
...@@ -347,7 +350,7 @@ func (p *pp) fmtInt64(v int64, verb int, value interface{}) { ...@@ -347,7 +350,7 @@ func (p *pp) fmtInt64(v int64, verb int, value interface{}) {
if 0 <= v && v <= unicode.MaxRune { if 0 <= v && v <= unicode.MaxRune {
p.fmt.fmt_qc(v) p.fmt.fmt_qc(v)
} else { } else {
p.badVerb(verb, value) p.badVerb(verb)
} }
case 'x': case 'x':
p.fmt.integer(v, 16, signed, ldigits) p.fmt.integer(v, 16, signed, ldigits)
...@@ -356,7 +359,7 @@ func (p *pp) fmtInt64(v int64, verb int, value interface{}) { ...@@ -356,7 +359,7 @@ func (p *pp) fmtInt64(v int64, verb int, value interface{}) {
case 'X': case 'X':
p.fmt.integer(v, 16, signed, udigits) p.fmt.integer(v, 16, signed, udigits)
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
...@@ -391,7 +394,7 @@ func (p *pp) fmtUnicode(v int64) { ...@@ -391,7 +394,7 @@ func (p *pp) fmtUnicode(v int64) {
p.fmt.sharp = sharp p.fmt.sharp = sharp
} }
func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool) {
switch verb { switch verb {
case 'b': case 'b':
p.fmt.integer(int64(v), 2, unsigned, ldigits) p.fmt.integer(int64(v), 2, unsigned, ldigits)
...@@ -411,7 +414,7 @@ func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { ...@@ -411,7 +414,7 @@ func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) {
if 0 <= v && v <= unicode.MaxRune { if 0 <= v && v <= unicode.MaxRune {
p.fmt.fmt_qc(int64(v)) p.fmt.fmt_qc(int64(v))
} else { } else {
p.badVerb(verb, value) p.badVerb(verb)
} }
case 'x': case 'x':
p.fmt.integer(int64(v), 16, unsigned, ldigits) p.fmt.integer(int64(v), 16, unsigned, ldigits)
...@@ -420,11 +423,11 @@ func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { ...@@ -420,11 +423,11 @@ func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) {
case 'U': case 'U':
p.fmtUnicode(int64(v)) p.fmtUnicode(int64(v))
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
func (p *pp) fmtFloat32(v float32, verb int, value interface{}) { func (p *pp) fmtFloat32(v float32, verb int) {
switch verb { switch verb {
case 'b': case 'b':
p.fmt.fmt_fb32(v) p.fmt.fmt_fb32(v)
...@@ -439,11 +442,11 @@ func (p *pp) fmtFloat32(v float32, verb int, value interface{}) { ...@@ -439,11 +442,11 @@ func (p *pp) fmtFloat32(v float32, verb int, value interface{}) {
case 'G': case 'G':
p.fmt.fmt_G32(v) p.fmt.fmt_G32(v)
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
func (p *pp) fmtFloat64(v float64, verb int, value interface{}) { func (p *pp) fmtFloat64(v float64, verb int) {
switch verb { switch verb {
case 'b': case 'b':
p.fmt.fmt_fb64(v) p.fmt.fmt_fb64(v)
...@@ -458,33 +461,33 @@ func (p *pp) fmtFloat64(v float64, verb int, value interface{}) { ...@@ -458,33 +461,33 @@ func (p *pp) fmtFloat64(v float64, verb int, value interface{}) {
case 'G': case 'G':
p.fmt.fmt_G64(v) p.fmt.fmt_G64(v)
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) { func (p *pp) fmtComplex64(v complex64, verb int) {
switch verb { switch verb {
case 'e', 'E', 'f', 'F', 'g', 'G': case 'e', 'E', 'f', 'F', 'g', 'G':
p.fmt.fmt_c64(v, verb) p.fmt.fmt_c64(v, verb)
case 'v': case 'v':
p.fmt.fmt_c64(v, 'g') p.fmt.fmt_c64(v, 'g')
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) { func (p *pp) fmtComplex128(v complex128, verb int) {
switch verb { switch verb {
case 'e', 'E', 'f', 'F', 'g', 'G': case 'e', 'E', 'f', 'F', 'g', 'G':
p.fmt.fmt_c128(v, verb) p.fmt.fmt_c128(v, verb)
case 'v': case 'v':
p.fmt.fmt_c128(v, 'g') p.fmt.fmt_c128(v, 'g')
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) { func (p *pp) fmtString(v string, verb int, goSyntax bool) {
switch verb { switch verb {
case 'v': case 'v':
if goSyntax { if goSyntax {
...@@ -501,11 +504,11 @@ func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) { ...@@ -501,11 +504,11 @@ func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) {
case 'q': case 'q':
p.fmt.fmt_q(v) p.fmt.fmt_q(v)
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) { func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int) {
if verb == 'v' || verb == 'd' { if verb == 'v' || verb == 'd' {
if goSyntax { if goSyntax {
p.buf.Write(bytesBytes) p.buf.Write(bytesBytes)
...@@ -540,17 +543,17 @@ func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interf ...@@ -540,17 +543,17 @@ func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interf
case 'q': case 'q':
p.fmt.fmt_q(s) p.fmt.fmt_q(s)
default: default:
p.badVerb(verb, value) p.badVerb(verb)
} }
} }
func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) { func (p *pp) fmtPointer(value reflect.Value, verb int, goSyntax bool) {
var u uintptr var u uintptr
switch value.Kind() { switch value.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
u = value.Pointer() u = value.Pointer()
default: default:
p.badVerb(verb, field) p.badVerb(verb)
return return
} }
if goSyntax { if goSyntax {
...@@ -576,12 +579,12 @@ var ( ...@@ -576,12 +579,12 @@ var (
uintptrBits = reflect.TypeOf(uintptr(0)).Bits() uintptrBits = reflect.TypeOf(uintptr(0)).Bits()
) )
func (p *pp) catchPanic(val interface{}, verb int) { func (p *pp) catchPanic(field interface{}, verb int) {
if err := recover(); err != nil { if err := recover(); err != nil {
// If it's a nil pointer, just say "<nil>". The likeliest causes are a // If it's a nil pointer, just say "<nil>". The likeliest causes are a
// Stringer that fails to guard against nil or a nil pointer for a // Stringer that fails to guard against nil or a nil pointer for a
// value receiver, and in either case, "<nil>" is a nice result. // value receiver, and in either case, "<nil>" is a nice result.
if v := reflect.ValueOf(val); v.Kind() == reflect.Ptr && v.IsNil() { if v := reflect.ValueOf(field); v.Kind() == reflect.Ptr && v.IsNil() {
p.buf.Write(nilAngleBytes) p.buf.Write(nilAngleBytes)
return return
} }
...@@ -601,12 +604,12 @@ func (p *pp) catchPanic(val interface{}, verb int) { ...@@ -601,12 +604,12 @@ func (p *pp) catchPanic(val interface{}, verb int) {
} }
} }
func (p *pp) handleMethods(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString, handled bool) { func (p *pp) handleMethods(verb int, plus, goSyntax bool, depth int) (wasString, handled bool) {
// Is it a Formatter? // Is it a Formatter?
if formatter, ok := field.(Formatter); ok { if formatter, ok := p.field.(Formatter); ok {
handled = true handled = true
wasString = false wasString = false
defer p.catchPanic(field, verb) defer p.catchPanic(p.field, verb)
formatter.Format(p, verb) formatter.Format(p, verb)
return return
} }
...@@ -618,20 +621,20 @@ func (p *pp) handleMethods(field interface{}, verb int, plus, goSyntax bool, dep ...@@ -618,20 +621,20 @@ func (p *pp) handleMethods(field interface{}, verb int, plus, goSyntax bool, dep
// If we're doing Go syntax and the field knows how to supply it, take care of it now. // If we're doing Go syntax and the field knows how to supply it, take care of it now.
if goSyntax { if goSyntax {
p.fmt.sharp = false p.fmt.sharp = false
if stringer, ok := field.(GoStringer); ok { if stringer, ok := p.field.(GoStringer); ok {
wasString = false wasString = false
handled = true handled = true
defer p.catchPanic(field, verb) defer p.catchPanic(p.field, verb)
// Print the result of GoString unadorned. // Print the result of GoString unadorned.
p.fmtString(stringer.GoString(), 's', false, field) p.fmtString(stringer.GoString(), 's', false)
return return
} }
} else { } else {
// Is it a Stringer? // Is it a Stringer?
if stringer, ok := field.(Stringer); ok { if stringer, ok := p.field.(Stringer); ok {
wasString = false wasString = false
handled = true handled = true
defer p.catchPanic(field, verb) defer p.catchPanic(p.field, verb)
p.printField(stringer.String(), verb, plus, false, depth) p.printField(stringer.String(), verb, plus, false, depth)
return return
} }
...@@ -645,11 +648,13 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth ...@@ -645,11 +648,13 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth
if verb == 'T' || verb == 'v' { if verb == 'T' || verb == 'v' {
p.buf.Write(nilAngleBytes) p.buf.Write(nilAngleBytes)
} else { } else {
p.badVerb(verb, field) p.badVerb(verb)
} }
return false return false
} }
p.field = field
p.value = reflect.Value{}
// Special processing considerations. // Special processing considerations.
// %T (the value's type) and %p (its address) are special; we always do them first. // %T (the value's type) and %p (its address) are special; we always do them first.
switch verb { switch verb {
...@@ -657,74 +662,60 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth ...@@ -657,74 +662,60 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth
p.printField(reflect.TypeOf(field).String(), 's', false, false, 0) p.printField(reflect.TypeOf(field).String(), 's', false, false, 0)
return false return false
case 'p': case 'p':
p.fmtPointer(field, reflect.ValueOf(field), verb, goSyntax) p.fmtPointer(reflect.ValueOf(field), verb, goSyntax)
return false return false
} }
if wasString, handled := p.handleMethods(field, verb, plus, goSyntax, depth); handled { if wasString, handled := p.handleMethods(verb, plus, goSyntax, depth); handled {
return wasString return wasString
} }
// Some types can be done without reflection. // Some types can be done without reflection.
switch f := field.(type) { switch f := field.(type) {
case bool: case bool:
p.fmtBool(f, verb, field) p.fmtBool(f, verb)
return false
case float32: case float32:
p.fmtFloat32(f, verb, field) p.fmtFloat32(f, verb)
return false
case float64: case float64:
p.fmtFloat64(f, verb, field) p.fmtFloat64(f, verb)
return false
case complex64: case complex64:
p.fmtComplex64(complex64(f), verb, field) p.fmtComplex64(complex64(f), verb)
return false
case complex128: case complex128:
p.fmtComplex128(f, verb, field) p.fmtComplex128(f, verb)
return false
case int: case int:
p.fmtInt64(int64(f), verb, field) p.fmtInt64(int64(f), verb)
return false
case int8: case int8:
p.fmtInt64(int64(f), verb, field) p.fmtInt64(int64(f), verb)
return false
case int16: case int16:
p.fmtInt64(int64(f), verb, field) p.fmtInt64(int64(f), verb)
return false
case int32: case int32:
p.fmtInt64(int64(f), verb, field) p.fmtInt64(int64(f), verb)
return false
case int64: case int64:
p.fmtInt64(f, verb, field) p.fmtInt64(f, verb)
return false
case uint: case uint:
p.fmtUint64(uint64(f), verb, goSyntax, field) p.fmtUint64(uint64(f), verb, goSyntax)
return false
case uint8: case uint8:
p.fmtUint64(uint64(f), verb, goSyntax, field) p.fmtUint64(uint64(f), verb, goSyntax)
return false
case uint16: case uint16:
p.fmtUint64(uint64(f), verb, goSyntax, field) p.fmtUint64(uint64(f), verb, goSyntax)
return false
case uint32: case uint32:
p.fmtUint64(uint64(f), verb, goSyntax, field) p.fmtUint64(uint64(f), verb, goSyntax)
return false
case uint64: case uint64:
p.fmtUint64(f, verb, goSyntax, field) p.fmtUint64(f, verb, goSyntax)
return false
case uintptr: case uintptr:
p.fmtUint64(uint64(f), verb, goSyntax, field) p.fmtUint64(uint64(f), verb, goSyntax)
return false
case string: case string:
p.fmtString(f, verb, goSyntax, field) p.fmtString(f, verb, goSyntax)
return verb == 's' || verb == 'v' wasString = verb == 's' || verb == 'v'
case []byte: case []byte:
p.fmtBytes(f, verb, goSyntax, depth, field) p.fmtBytes(f, verb, goSyntax, depth)
return verb == 's' wasString = verb == 's'
} default:
// Need to use reflection // Need to use reflection
return p.printReflectValue(reflect.ValueOf(field), verb, plus, goSyntax, depth) return p.printReflectValue(reflect.ValueOf(field), verb, plus, goSyntax, depth)
}
p.field = nil
return
} }
// printValue is like printField but starts with a reflect value, not an interface{} value. // printValue is like printField but starts with a reflect value, not an interface{} value.
...@@ -733,7 +724,7 @@ func (p *pp) printValue(value reflect.Value, verb int, plus, goSyntax bool, dept ...@@ -733,7 +724,7 @@ func (p *pp) printValue(value reflect.Value, verb int, plus, goSyntax bool, dept
if verb == 'T' || verb == 'v' { if verb == 'T' || verb == 'v' {
p.buf.Write(nilAngleBytes) p.buf.Write(nilAngleBytes)
} else { } else {
p.badVerb(verb, nil) p.badVerb(verb)
} }
return false return false
} }
...@@ -745,17 +736,17 @@ func (p *pp) printValue(value reflect.Value, verb int, plus, goSyntax bool, dept ...@@ -745,17 +736,17 @@ func (p *pp) printValue(value reflect.Value, verb int, plus, goSyntax bool, dept
p.printField(value.Type().String(), 's', false, false, 0) p.printField(value.Type().String(), 's', false, false, 0)
return false return false
case 'p': case 'p':
p.fmtPointer(nil, value, verb, goSyntax) p.fmtPointer(value, verb, goSyntax)
return false return false
} }
// Handle values with special methods. // Handle values with special methods.
// Call always, even when field == nil, because handleMethods clears p.fmt.plus for us. // Call always, even when field == nil, because handleMethods clears p.fmt.plus for us.
var field interface{} p.field = nil // Make sure it's cleared, for safety.
if value.CanInterface() { if value.CanInterface() {
field = value.Interface() p.field = value.Interface()
} }
if wasString, handled := p.handleMethods(field, verb, plus, goSyntax, depth); handled { if wasString, handled := p.handleMethods(verb, plus, goSyntax, depth); handled {
return wasString return wasString
} }
...@@ -770,25 +761,25 @@ func (p *pp) printReflectValue(value reflect.Value, verb int, plus, goSyntax boo ...@@ -770,25 +761,25 @@ func (p *pp) printReflectValue(value reflect.Value, verb int, plus, goSyntax boo
BigSwitch: BigSwitch:
switch f := value; f.Kind() { switch f := value; f.Kind() {
case reflect.Bool: case reflect.Bool:
p.fmtBool(f.Bool(), verb, nil) p.fmtBool(f.Bool(), verb)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
p.fmtInt64(f.Int(), verb, nil) p.fmtInt64(f.Int(), verb)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
p.fmtUint64(uint64(f.Uint()), verb, goSyntax, nil) p.fmtUint64(uint64(f.Uint()), verb, goSyntax)
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
if f.Type().Size() == 4 { if f.Type().Size() == 4 {
p.fmtFloat32(float32(f.Float()), verb, nil) p.fmtFloat32(float32(f.Float()), verb)
} else { } else {
p.fmtFloat64(float64(f.Float()), verb, nil) p.fmtFloat64(float64(f.Float()), verb)
} }
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
if f.Type().Size() == 8 { if f.Type().Size() == 8 {
p.fmtComplex64(complex64(f.Complex()), verb, nil) p.fmtComplex64(complex64(f.Complex()), verb)
} else { } else {
p.fmtComplex128(complex128(f.Complex()), verb, nil) p.fmtComplex128(complex128(f.Complex()), verb)
} }
case reflect.String: case reflect.String:
p.fmtString(f.String(), verb, goSyntax, nil) p.fmtString(f.String(), verb, goSyntax)
case reflect.Map: case reflect.Map:
if goSyntax { if goSyntax {
p.buf.WriteString(f.Type().String()) p.buf.WriteString(f.Type().String())
...@@ -842,7 +833,7 @@ BigSwitch: ...@@ -842,7 +833,7 @@ BigSwitch:
value := f.Elem() value := f.Elem()
if !value.IsValid() { if !value.IsValid() {
if goSyntax { if goSyntax {
p.buf.WriteString(value.Type().String()) p.buf.WriteString(f.Type().String())
p.buf.Write(nilParenBytes) p.buf.Write(nilParenBytes)
} else { } else {
p.buf.Write(nilAngleBytes) p.buf.Write(nilAngleBytes)
...@@ -864,7 +855,7 @@ BigSwitch: ...@@ -864,7 +855,7 @@ BigSwitch:
for i := range bytes { for i := range bytes {
bytes[i] = byte(f.Index(i).Uint()) bytes[i] = byte(f.Index(i).Uint())
} }
p.fmtBytes(bytes, verb, goSyntax, depth, nil) p.fmtBytes(bytes, verb, goSyntax, depth)
wasString = verb == 's' wasString = verb == 's'
break break
} }
...@@ -924,7 +915,7 @@ BigSwitch: ...@@ -924,7 +915,7 @@ BigSwitch:
} }
p.fmt0x64(uint64(v), true) p.fmt0x64(uint64(v), true)
case reflect.Chan, reflect.Func, reflect.UnsafePointer: case reflect.Chan, reflect.Func, reflect.UnsafePointer:
p.fmtPointer(nil, value, verb, goSyntax) p.fmtPointer(value, verb, goSyntax)
default: default:
p.unknownType(f) p.unknownType(f)
} }
......
...@@ -23,11 +23,10 @@ var tests = []struct { ...@@ -23,11 +23,10 @@ var tests = []struct {
{"foobar", "0 \"foobar\""}, {"foobar", "0 \"foobar\""},
// maps // maps
{map[string]int{"a": 1, "b": 2}, {map[string]int{"a": 1},
`0 map[string] int (len = 2) { `0 map[string] int (len = 1) {
1 . "a": 1 1 . "a": 1
2 . "b": 2 2 }`},
3 }`},
// pointers // pointers
{new(int), "0 *0"}, {new(int), "0 *0"},
......
...@@ -70,9 +70,6 @@ call to Next. For example, to extract an HTML page's anchor text: ...@@ -70,9 +70,6 @@ call to Next. For example, to extract an HTML page's anchor text:
} }
} }
A Tokenizer typically skips over HTML comments. To return comment tokens, set
Tokenizer.ReturnComments to true before looping over calls to Next.
Parsing is done by calling Parse with an io.Reader, which returns the root of Parsing is done by calling Parse with an io.Reader, which returns the root of
the parse tree (the document element) as a *Node. It is the caller's the parse tree (the document element) as a *Node. It is the caller's
responsibility to ensure that the Reader provides UTF-8 encoded HTML. For responsibility to ensure that the Reader provides UTF-8 encoded HTML. For
......
...@@ -32,6 +32,9 @@ type parser struct { ...@@ -32,6 +32,9 @@ type parser struct {
// originalIM is the insertion mode to go back to after completing a text // originalIM is the insertion mode to go back to after completing a text
// or inTableText insertion mode. // or inTableText insertion mode.
originalIM insertionMode originalIM insertionMode
// fosterParenting is whether new elements should be inserted according to
// the foster parenting rules (section 11.2.5.3).
fosterParenting bool
} }
func (p *parser) top() *Node { func (p *parser) top() *Node {
...@@ -49,6 +52,11 @@ var ( ...@@ -49,6 +52,11 @@ var (
tableScopeStopTags = []string{"html", "table"} tableScopeStopTags = []string{"html", "table"}
) )
// stopTags for use in clearStackToContext.
var (
tableRowContextStopTags = []string{"tr", "html"}
)
// popUntil pops the stack of open elements at the highest element whose tag // popUntil pops the stack of open elements at the highest element whose tag
// is in matchTags, provided there is no higher element in stopTags. It returns // is in matchTags, provided there is no higher element in stopTags. It returns
// whether or not there was such an element. If there was not, popUntil leaves // whether or not there was such an element. If there was not, popUntil leaves
...@@ -103,12 +111,61 @@ func (p *parser) elementInScope(stopTags []string, matchTags ...string) bool { ...@@ -103,12 +111,61 @@ func (p *parser) elementInScope(stopTags []string, matchTags ...string) bool {
// addChild adds a child node n to the top element, and pushes n onto the stack // addChild adds a child node n to the top element, and pushes n onto the stack
// of open elements if it is an element node. // of open elements if it is an element node.
func (p *parser) addChild(n *Node) { func (p *parser) addChild(n *Node) {
if p.fosterParenting {
p.fosterParent(n)
} else {
p.top().Add(n) p.top().Add(n)
}
if n.Type == ElementNode { if n.Type == ElementNode {
p.oe = append(p.oe, n) p.oe = append(p.oe, n)
} }
} }
// fosterParent adds a child node according to the foster parenting rules.
// Section 11.2.5.3, "foster parenting".
func (p *parser) fosterParent(n *Node) {
var table, parent *Node
var i int
for i = len(p.oe) - 1; i >= 0; i-- {
if p.oe[i].Data == "table" {
table = p.oe[i]
break
}
}
if table == nil {
// The foster parent is the html element.
parent = p.oe[0]
} else {
parent = table.Parent
}
if parent == nil {
parent = p.oe[i-1]
}
var child *Node
for i, child = range parent.Child {
if child == table {
break
}
}
if i > 0 && parent.Child[i-1].Type == TextNode && n.Type == TextNode {
parent.Child[i-1].Data += n.Data
return
}
if i == len(parent.Child) {
parent.Add(n)
} else {
// Insert n into parent.Child at index i.
parent.Child = append(parent.Child[:i+1], parent.Child[i:]...)
parent.Child[i] = n
n.Parent = parent
}
}
// addText adds text to the preceding node if it is a text node, or else it // addText adds text to the preceding node if it is a text node, or else it
// calls addChild with a new text node. // calls addChild with a new text node.
func (p *parser) addText(text string) { func (p *parser) addText(text string) {
...@@ -170,9 +227,9 @@ func (p *parser) reconstructActiveFormattingElements() { ...@@ -170,9 +227,9 @@ func (p *parser) reconstructActiveFormattingElements() {
} }
for { for {
i++ i++
n = p.afe[i] clone := p.afe[i].clone()
p.addChild(n.clone()) p.addChild(clone)
p.afe[i] = n p.afe[i] = clone
if i == len(p.afe)-1 { if i == len(p.afe)-1 {
break break
} }
...@@ -234,10 +291,52 @@ func (p *parser) setOriginalIM(im insertionMode) { ...@@ -234,10 +291,52 @@ func (p *parser) setOriginalIM(im insertionMode) {
p.originalIM = im p.originalIM = im
} }
// Section 11.2.3.1, "reset the insertion mode".
func (p *parser) resetInsertionMode() insertionMode {
for i := len(p.oe) - 1; i >= 0; i-- {
n := p.oe[i]
if i == 0 {
// TODO: set n to the context element, for HTML fragment parsing.
}
switch n.Data {
case "select":
return inSelectIM
case "td", "th":
return inCellIM
case "tr":
return inRowIM
case "tbody", "thead", "tfoot":
return inTableBodyIM
case "caption":
// TODO: return inCaptionIM
case "colgroup":
// TODO: return inColumnGroupIM
case "table":
return inTableIM
case "head":
return inBodyIM
case "body":
return inBodyIM
case "frameset":
// TODO: return inFramesetIM
case "html":
return beforeHeadIM
}
}
return inBodyIM
}
// Section 11.2.5.4.1. // Section 11.2.5.4.1.
func initialIM(p *parser) (insertionMode, bool) { func initialIM(p *parser) (insertionMode, bool) {
if p.tok.Type == DoctypeToken { switch p.tok.Type {
p.addChild(&Node{ case CommentToken:
p.doc.Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return initialIM, true
case DoctypeToken:
p.doc.Add(&Node{
Type: DoctypeNode, Type: DoctypeNode,
Data: p.tok.Data, Data: p.tok.Data,
}) })
...@@ -275,6 +374,12 @@ func beforeHTMLIM(p *parser) (insertionMode, bool) { ...@@ -275,6 +374,12 @@ func beforeHTMLIM(p *parser) (insertionMode, bool) {
default: default:
// Ignore the token. // Ignore the token.
} }
case CommentToken:
p.doc.Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return beforeHTMLIM, true
} }
if add || implied { if add || implied {
p.addElement("html", attr) p.addElement("html", attr)
...@@ -312,6 +417,12 @@ func beforeHeadIM(p *parser) (insertionMode, bool) { ...@@ -312,6 +417,12 @@ func beforeHeadIM(p *parser) (insertionMode, bool) {
default: default:
// Ignore the token. // Ignore the token.
} }
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return beforeHeadIM, true
} }
if add || implied { if add || implied {
p.addElement("head", attr) p.addElement("head", attr)
...@@ -344,11 +455,17 @@ func inHeadIM(p *parser) (insertionMode, bool) { ...@@ -344,11 +455,17 @@ func inHeadIM(p *parser) (insertionMode, bool) {
pop = true pop = true
} }
// TODO. // TODO.
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return inHeadIM, true
} }
if pop || implied { if pop || implied {
n := p.oe.pop() n := p.oe.pop()
if n.Data != "head" { if n.Data != "head" {
panic("html: bad parser state") panic("html: bad parser state: <head> element not found, in the in-head insertion mode")
} }
return afterHeadIM, !implied return afterHeadIM, !implied
} }
...@@ -387,6 +504,12 @@ func afterHeadIM(p *parser) (insertionMode, bool) { ...@@ -387,6 +504,12 @@ func afterHeadIM(p *parser) (insertionMode, bool) {
} }
case EndTagToken: case EndTagToken:
// TODO. // TODO.
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return afterHeadIM, true
} }
if add || implied { if add || implied {
p.addElement("body", attr) p.addElement("body", attr)
...@@ -447,6 +570,30 @@ func inBodyIM(p *parser) (insertionMode, bool) { ...@@ -447,6 +570,30 @@ func inBodyIM(p *parser) (insertionMode, bool) {
p.oe.pop() p.oe.pop()
p.acknowledgeSelfClosingTag() p.acknowledgeSelfClosingTag()
p.framesetOK = false p.framesetOK = false
case "select":
p.reconstructActiveFormattingElements()
p.addElement(p.tok.Data, p.tok.Attr)
p.framesetOK = false
// TODO: detect <select> inside a table.
return inSelectIM, true
case "li":
p.framesetOK = false
for i := len(p.oe) - 1; i >= 0; i-- {
node := p.oe[i]
switch node.Data {
case "li":
p.popUntil(listItemScopeStopTags, "li")
case "address", "div", "p":
continue
default:
if !isSpecialElement[node.Data] {
continue
}
}
break
}
p.popUntil(buttonScopeStopTags, "p")
p.addElement("li", p.tok.Attr)
default: default:
// TODO. // TODO.
p.addElement(p.tok.Data, p.tok.Attr) p.addElement(p.tok.Data, p.tok.Attr)
...@@ -463,12 +610,16 @@ func inBodyIM(p *parser) (insertionMode, bool) { ...@@ -463,12 +610,16 @@ func inBodyIM(p *parser) (insertionMode, bool) {
p.popUntil(buttonScopeStopTags, "p") p.popUntil(buttonScopeStopTags, "p")
case "a", "b", "big", "code", "em", "font", "i", "nobr", "s", "small", "strike", "strong", "tt", "u": case "a", "b", "big", "code", "em", "font", "i", "nobr", "s", "small", "strike", "strong", "tt", "u":
p.inBodyEndTagFormatting(p.tok.Data) p.inBodyEndTagFormatting(p.tok.Data)
case "address", "article", "aside", "blockquote", "button", "center", "details", "dir", "div", "dl", "fieldset", "figcaption", "figure", "footer", "header", "hgroup", "listing", "menu", "nav", "ol", "pre", "section", "summary", "ul":
p.popUntil(defaultScopeStopTags, p.tok.Data)
default: default:
// TODO: any other end tag p.inBodyEndTagOther(p.tok.Data)
if p.tok.Data == p.top().Data {
p.oe.pop()
}
} }
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
} }
return inBodyIM, true return inBodyIM, true
...@@ -496,6 +647,7 @@ func (p *parser) inBodyEndTagFormatting(tag string) { ...@@ -496,6 +647,7 @@ func (p *parser) inBodyEndTagFormatting(tag string) {
} }
} }
if formattingElement == nil { if formattingElement == nil {
p.inBodyEndTagOther(tag)
return return
} }
feIndex := p.oe.index(formattingElement) feIndex := p.oe.index(formattingElement)
...@@ -568,8 +720,7 @@ func (p *parser) inBodyEndTagFormatting(tag string) { ...@@ -568,8 +720,7 @@ func (p *parser) inBodyEndTagFormatting(tag string) {
} }
switch commonAncestor.Data { switch commonAncestor.Data {
case "table", "tbody", "tfoot", "thead", "tr": case "table", "tbody", "tfoot", "thead", "tr":
// TODO: fix up misnested table nodes; find the foster parent. p.fosterParent(lastNode)
fallthrough
default: default:
commonAncestor.Add(lastNode) commonAncestor.Add(lastNode)
} }
...@@ -590,6 +741,19 @@ func (p *parser) inBodyEndTagFormatting(tag string) { ...@@ -590,6 +741,19 @@ func (p *parser) inBodyEndTagFormatting(tag string) {
} }
} }
// inBodyEndTagOther performs the "any other end tag" algorithm for inBodyIM.
func (p *parser) inBodyEndTagOther(tag string) {
for i := len(p.oe) - 1; i >= 0; i-- {
if p.oe[i].Data == tag {
p.oe = p.oe[:i]
break
}
if isSpecialElement[p.oe[i].Data] {
break
}
}
}
// Section 11.2.5.4.8. // Section 11.2.5.4.8.
func textIM(p *parser) (insertionMode, bool) { func textIM(p *parser) (insertionMode, bool) {
switch p.tok.Type { switch p.tok.Type {
...@@ -606,12 +770,6 @@ func textIM(p *parser) (insertionMode, bool) { ...@@ -606,12 +770,6 @@ func textIM(p *parser) (insertionMode, bool) {
// Section 11.2.5.4.9. // Section 11.2.5.4.9.
func inTableIM(p *parser) (insertionMode, bool) { func inTableIM(p *parser) (insertionMode, bool) {
var (
add bool
data string
attr []Attribute
consumed bool
)
switch p.tok.Type { switch p.tok.Type {
case ErrorToken: case ErrorToken:
// Stop parsing. // Stop parsing.
...@@ -621,13 +779,19 @@ func inTableIM(p *parser) (insertionMode, bool) { ...@@ -621,13 +779,19 @@ func inTableIM(p *parser) (insertionMode, bool) {
case StartTagToken: case StartTagToken:
switch p.tok.Data { switch p.tok.Data {
case "tbody", "tfoot", "thead": case "tbody", "tfoot", "thead":
add = true p.clearStackToContext(tableScopeStopTags)
data = p.tok.Data p.addElement(p.tok.Data, p.tok.Attr)
attr = p.tok.Attr return inTableBodyIM, true
consumed = true
case "td", "th", "tr": case "td", "th", "tr":
add = true p.clearStackToContext(tableScopeStopTags)
data = "tbody" p.addElement("tbody", nil)
return inTableBodyIM, false
case "table":
if p.popUntil(tableScopeStopTags, "table") {
return p.resetInsertionMode(), false
}
// Ignore the token.
return inTableIM, true
default: default:
// TODO. // TODO.
} }
...@@ -635,8 +799,7 @@ func inTableIM(p *parser) (insertionMode, bool) { ...@@ -635,8 +799,7 @@ func inTableIM(p *parser) (insertionMode, bool) {
switch p.tok.Data { switch p.tok.Data {
case "table": case "table":
if p.popUntil(tableScopeStopTags, "table") { if p.popUntil(tableScopeStopTags, "table") {
// TODO: "reset the insertion mode appropriately" as per 11.2.3.1. return p.resetInsertionMode(), true
return inBodyIM, false
} }
// Ignore the token. // Ignore the token.
return inTableIM, true return inTableIM, true
...@@ -644,14 +807,34 @@ func inTableIM(p *parser) (insertionMode, bool) { ...@@ -644,14 +807,34 @@ func inTableIM(p *parser) (insertionMode, bool) {
// Ignore the token. // Ignore the token.
return inTableIM, true return inTableIM, true
} }
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return inTableIM, true
}
switch p.top().Data {
case "table", "tbody", "tfoot", "thead", "tr":
p.fosterParenting = true
defer func() { p.fosterParenting = false }()
}
return useTheRulesFor(p, inTableIM, inBodyIM)
}
// clearStackToContext pops elements off the stack of open elements
// until an element listed in stopTags is found.
func (p *parser) clearStackToContext(stopTags []string) {
for i := len(p.oe) - 1; i >= 0; i-- {
for _, tag := range stopTags {
if p.oe[i].Data == tag {
p.oe = p.oe[:i+1]
return
}
} }
if add {
// TODO: clear the stack back to a table context.
p.addElement(data, attr)
return inTableBodyIM, consumed
} }
// TODO: return useTheRulesFor(inTableIM, inBodyIM, p) unless etc. etc. foster parenting.
return inTableIM, true
} }
// Section 11.2.5.4.13. // Section 11.2.5.4.13.
...@@ -693,6 +876,12 @@ func inTableBodyIM(p *parser) (insertionMode, bool) { ...@@ -693,6 +876,12 @@ func inTableBodyIM(p *parser) (insertionMode, bool) {
// Ignore the token. // Ignore the token.
return inTableBodyIM, true return inTableBodyIM, true
} }
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return inTableBodyIM, true
} }
if add { if add {
// TODO: clear the stack back to a table body context. // TODO: clear the stack back to a table body context.
...@@ -722,7 +911,12 @@ func inRowIM(p *parser) (insertionMode, bool) { ...@@ -722,7 +911,12 @@ func inRowIM(p *parser) (insertionMode, bool) {
case EndTagToken: case EndTagToken:
switch p.tok.Data { switch p.tok.Data {
case "tr": case "tr":
// TODO. if !p.elementInScope(tableScopeStopTags, "tr") {
return inRowIM, true
}
p.clearStackToContext(tableRowContextStopTags)
p.oe.pop()
return inTableBodyIM, true
case "table": case "table":
if p.popUntil(tableScopeStopTags, "tr") { if p.popUntil(tableScopeStopTags, "tr") {
return inTableBodyIM, false return inTableBodyIM, false
...@@ -737,6 +931,12 @@ func inRowIM(p *parser) (insertionMode, bool) { ...@@ -737,6 +931,12 @@ func inRowIM(p *parser) (insertionMode, bool) {
default: default:
// TODO. // TODO.
} }
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return inRowIM, true
} }
return useTheRulesFor(p, inRowIM, inTableIM) return useTheRulesFor(p, inRowIM, inTableIM)
} }
...@@ -763,6 +963,12 @@ func inCellIM(p *parser) (insertionMode, bool) { ...@@ -763,6 +963,12 @@ func inCellIM(p *parser) (insertionMode, bool) {
// TODO: check for matching element in table scope. // TODO: check for matching element in table scope.
closeTheCellAndReprocess = true closeTheCellAndReprocess = true
} }
case CommentToken:
p.addChild(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return inCellIM, true
} }
if closeTheCellAndReprocess { if closeTheCellAndReprocess {
if p.popUntil(tableScopeStopTags, "td") || p.popUntil(tableScopeStopTags, "th") { if p.popUntil(tableScopeStopTags, "td") || p.popUntil(tableScopeStopTags, "th") {
...@@ -773,6 +979,68 @@ func inCellIM(p *parser) (insertionMode, bool) { ...@@ -773,6 +979,68 @@ func inCellIM(p *parser) (insertionMode, bool) {
return useTheRulesFor(p, inCellIM, inBodyIM) return useTheRulesFor(p, inCellIM, inBodyIM)
} }
// Section 11.2.5.4.16.
func inSelectIM(p *parser) (insertionMode, bool) {
endSelect := false
switch p.tok.Type {
case ErrorToken:
// TODO.
case TextToken:
p.addText(p.tok.Data)
case StartTagToken:
switch p.tok.Data {
case "html":
// TODO.
case "option":
if p.top().Data == "option" {
p.oe.pop()
}
p.addElement(p.tok.Data, p.tok.Attr)
case "optgroup":
// TODO.
case "select":
endSelect = true
case "input", "keygen", "textarea":
// TODO.
case "script":
// TODO.
default:
// Ignore the token.
}
case EndTagToken:
switch p.tok.Data {
case "option":
// TODO.
case "optgroup":
// TODO.
case "select":
endSelect = true
default:
// Ignore the token.
}
case CommentToken:
p.doc.Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
}
if endSelect {
for i := len(p.oe) - 1; i >= 0; i-- {
switch p.oe[i].Data {
case "select":
p.oe = p.oe[:i]
return p.resetInsertionMode(), true
case "option", "optgroup":
continue
default:
// Ignore the token.
return inSelectIM, true
}
}
}
return inSelectIM, true
}
// Section 11.2.5.4.18. // Section 11.2.5.4.18.
func afterBodyIM(p *parser) (insertionMode, bool) { func afterBodyIM(p *parser) (insertionMode, bool) {
switch p.tok.Type { switch p.tok.Type {
...@@ -790,7 +1058,18 @@ func afterBodyIM(p *parser) (insertionMode, bool) { ...@@ -790,7 +1058,18 @@ func afterBodyIM(p *parser) (insertionMode, bool) {
default: default:
// TODO. // TODO.
} }
case CommentToken:
// The comment is attached to the <html> element.
if len(p.oe) < 1 || p.oe[0].Data != "html" {
panic("html: bad parser state: <html> element not found, in the after-body insertion mode")
} }
p.oe[0].Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return afterBodyIM, true
}
// TODO: should this be "return inBodyIM, true"?
return afterBodyIM, true return afterBodyIM, true
} }
...@@ -806,6 +1085,12 @@ func afterAfterBodyIM(p *parser) (insertionMode, bool) { ...@@ -806,6 +1085,12 @@ func afterAfterBodyIM(p *parser) (insertionMode, bool) {
if p.tok.Data == "html" { if p.tok.Data == "html" {
return useTheRulesFor(p, afterAfterBodyIM, inBodyIM) return useTheRulesFor(p, afterAfterBodyIM, inBodyIM)
} }
case CommentToken:
p.doc.Add(&Node{
Type: CommentNode,
Data: p.tok.Data,
})
return afterAfterBodyIM, true
} }
return inBodyIM, false return inBodyIM, false
} }
......
...@@ -69,11 +69,15 @@ func readDat(filename string, c chan io.Reader) { ...@@ -69,11 +69,15 @@ func readDat(filename string, c chan io.Reader) {
} }
} }
func dumpLevel(w io.Writer, n *Node, level int) os.Error { func dumpIndent(w io.Writer, level int) {
io.WriteString(w, "| ") io.WriteString(w, "| ")
for i := 0; i < level; i++ { for i := 0; i < level; i++ {
io.WriteString(w, " ") io.WriteString(w, " ")
} }
}
func dumpLevel(w io.Writer, n *Node, level int) os.Error {
dumpIndent(w, level)
switch n.Type { switch n.Type {
case ErrorNode: case ErrorNode:
return os.NewError("unexpected ErrorNode") return os.NewError("unexpected ErrorNode")
...@@ -81,10 +85,15 @@ func dumpLevel(w io.Writer, n *Node, level int) os.Error { ...@@ -81,10 +85,15 @@ func dumpLevel(w io.Writer, n *Node, level int) os.Error {
return os.NewError("unexpected DocumentNode") return os.NewError("unexpected DocumentNode")
case ElementNode: case ElementNode:
fmt.Fprintf(w, "<%s>", n.Data) fmt.Fprintf(w, "<%s>", n.Data)
for _, a := range n.Attr {
io.WriteString(w, "\n")
dumpIndent(w, level+1)
fmt.Fprintf(w, `%s="%s"`, a.Key, a.Val)
}
case TextNode: case TextNode:
fmt.Fprintf(w, "%q", n.Data) fmt.Fprintf(w, "%q", n.Data)
case CommentNode: case CommentNode:
return os.NewError("COMMENT") fmt.Fprintf(w, "<!-- %s -->", n.Data)
case DoctypeNode: case DoctypeNode:
fmt.Fprintf(w, "<!DOCTYPE %s>", n.Data) fmt.Fprintf(w, "<!DOCTYPE %s>", n.Data)
case scopeMarkerNode: case scopeMarkerNode:
...@@ -123,7 +132,7 @@ func TestParser(t *testing.T) { ...@@ -123,7 +132,7 @@ func TestParser(t *testing.T) {
rc := make(chan io.Reader) rc := make(chan io.Reader)
go readDat(filename, rc) go readDat(filename, rc)
// TODO(nigeltao): Process all test cases, not just a subset. // TODO(nigeltao): Process all test cases, not just a subset.
for i := 0; i < 27; i++ { for i := 0; i < 34; i++ {
// Parse the #data section. // Parse the #data section.
b, err := ioutil.ReadAll(<-rc) b, err := ioutil.ReadAll(<-rc)
if err != nil { if err != nil {
...@@ -152,6 +161,13 @@ func TestParser(t *testing.T) { ...@@ -152,6 +161,13 @@ func TestParser(t *testing.T) {
continue continue
} }
// Check that rendering and re-parsing results in an identical tree. // Check that rendering and re-parsing results in an identical tree.
if filename == "tests1.dat" && i == 30 {
// Test 30 in tests1.dat is such messed-up markup that a correct parse
// results in a non-conforming tree (one <a> element nested inside another).
// Therefore when it is rendered and re-parsed, it isn't the same.
// So we skip rendering on that test.
continue
}
pr, pw := io.Pipe() pr, pw := io.Pipe()
go func() { go func() {
pw.CloseWithError(Render(pw, doc)) pw.CloseWithError(Render(pw, doc))
......
...@@ -30,9 +30,6 @@ type writer interface { ...@@ -30,9 +30,6 @@ type writer interface {
// would become a tree containing <html>, <head> and <body> elements. Another // would become a tree containing <html>, <head> and <body> elements. Another
// example is that the programmatic equivalent of "a<head>b</head>c" becomes // example is that the programmatic equivalent of "a<head>b</head>c" becomes
// "<html><head><head/><body>abc</body></html>". // "<html><head><head/><body>abc</body></html>".
//
// Comment nodes are elided from the output, analogous to Parse skipping over
// any <!--comment--> input.
func Render(w io.Writer, n *Node) os.Error { func Render(w io.Writer, n *Node) os.Error {
if x, ok := w.(writer); ok { if x, ok := w.(writer); ok {
return render(x, n) return render(x, n)
...@@ -61,6 +58,15 @@ func render(w writer, n *Node) os.Error { ...@@ -61,6 +58,15 @@ func render(w writer, n *Node) os.Error {
case ElementNode: case ElementNode:
// No-op. // No-op.
case CommentNode: case CommentNode:
if _, err := w.WriteString("<!--"); err != nil {
return err
}
if _, err := w.WriteString(n.Data); err != nil {
return err
}
if _, err := w.WriteString("-->"); err != nil {
return err
}
return nil return nil
case DoctypeNode: case DoctypeNode:
if _, err := w.WriteString("<!DOCTYPE "); err != nil { if _, err := w.WriteString("<!DOCTYPE "); err != nil {
......
...@@ -116,10 +116,6 @@ type span struct { ...@@ -116,10 +116,6 @@ type span struct {
// A Tokenizer returns a stream of HTML Tokens. // A Tokenizer returns a stream of HTML Tokens.
type Tokenizer struct { type Tokenizer struct {
// If ReturnComments is set, Next returns comment tokens;
// otherwise it skips over comments (default).
ReturnComments bool
// r is the source of the HTML text. // r is the source of the HTML text.
r io.Reader r io.Reader
// tt is the TokenType of the current token. // tt is the TokenType of the current token.
...@@ -546,17 +542,19 @@ func (z *Tokenizer) readTagAttrVal() { ...@@ -546,17 +542,19 @@ func (z *Tokenizer) readTagAttrVal() {
} }
} }
// next scans the next token and returns its type. // Next scans the next token and returns its type.
func (z *Tokenizer) next() TokenType { func (z *Tokenizer) Next() TokenType {
if z.err != nil { if z.err != nil {
return ErrorToken z.tt = ErrorToken
return z.tt
} }
z.raw.start = z.raw.end z.raw.start = z.raw.end
z.data.start = z.raw.end z.data.start = z.raw.end
z.data.end = z.raw.end z.data.end = z.raw.end
if z.rawTag != "" { if z.rawTag != "" {
z.readRawOrRCDATA() z.readRawOrRCDATA()
return TextToken z.tt = TextToken
return z.tt
} }
z.textIsRaw = false z.textIsRaw = false
...@@ -596,11 +594,13 @@ loop: ...@@ -596,11 +594,13 @@ loop:
if x := z.raw.end - len("<a"); z.raw.start < x { if x := z.raw.end - len("<a"); z.raw.start < x {
z.raw.end = x z.raw.end = x
z.data.end = x z.data.end = x
return TextToken z.tt = TextToken
return z.tt
} }
switch tokenType { switch tokenType {
case StartTagToken: case StartTagToken:
return z.readStartTag() z.tt = z.readStartTag()
return z.tt
case EndTagToken: case EndTagToken:
c = z.readByte() c = z.readByte()
if z.err != nil { if z.err != nil {
...@@ -616,39 +616,31 @@ loop: ...@@ -616,39 +616,31 @@ loop:
} }
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' { if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' {
z.readEndTag() z.readEndTag()
return EndTagToken z.tt = EndTagToken
return z.tt
} }
z.raw.end-- z.raw.end--
z.readUntilCloseAngle() z.readUntilCloseAngle()
return CommentToken z.tt = CommentToken
return z.tt
case CommentToken: case CommentToken:
if c == '!' { if c == '!' {
return z.readMarkupDeclaration() z.tt = z.readMarkupDeclaration()
return z.tt
} }
z.raw.end-- z.raw.end--
z.readUntilCloseAngle() z.readUntilCloseAngle()
return CommentToken z.tt = CommentToken
return z.tt
} }
} }
if z.raw.start < z.raw.end { if z.raw.start < z.raw.end {
z.data.end = z.raw.end z.data.end = z.raw.end
return TextToken z.tt = TextToken
}
return ErrorToken
}
// Next scans the next token and returns its type.
func (z *Tokenizer) Next() TokenType {
for {
z.tt = z.next()
// TODO: remove the ReturnComments option. A tokenizer should
// always return comment tags.
if z.tt == CommentToken && !z.ReturnComments {
continue
}
return z.tt return z.tt
} }
panic("unreachable") z.tt = ErrorToken
return z.tt
} }
// Raw returns the unmodified text of the current token. Calling Next, Token, // Raw returns the unmodified text of the current token. Calling Next, Token,
......
...@@ -424,7 +424,6 @@ func TestTokenizer(t *testing.T) { ...@@ -424,7 +424,6 @@ func TestTokenizer(t *testing.T) {
loop: loop:
for _, tt := range tokenTests { for _, tt := range tokenTests {
z := NewTokenizer(strings.NewReader(tt.html)) z := NewTokenizer(strings.NewReader(tt.html))
z.ReturnComments = true
if tt.golden != "" { if tt.golden != "" {
for i, s := range strings.Split(tt.golden, "$") { for i, s := range strings.Split(tt.golden, "$") {
if z.Next() == ErrorToken { if z.Next() == ErrorToken {
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Primitive HTTP client. See RFC 2616. // HTTP client. See RFC 2616.
//
// This is the high-level Client interface.
// The low-level implementation is in transport.go.
package http package http
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
package http_test package http_test
import ( import (
"crypto/tls"
"fmt" "fmt"
. "http" . "http"
"http/httptest" "http/httptest"
...@@ -292,3 +293,26 @@ func TestClientWrites(t *testing.T) { ...@@ -292,3 +293,26 @@ func TestClientWrites(t *testing.T) {
t.Errorf("Post request did %d Write calls, want 1", writes) t.Errorf("Post request did %d Write calls, want 1", writes)
} }
} }
func TestClientInsecureTransport(t *testing.T) {
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello"))
}))
defer ts.Close()
// TODO(bradfitz): add tests for skipping hostname checks too?
// would require a new cert for testing, and probably
// redundant with these tests.
for _, insecure := range []bool{true, false} {
tr := &Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: insecure,
},
}
c := &Client{Transport: tr}
_, err := c.Get(ts.URL)
if (err == nil) != insecure {
t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
}
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package http provides HTTP client and server implementations.
Get, Head, Post, and PostForm make HTTP requests:
resp, err := http.Get("http://example.com/")
...
resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf)
...
resp, err := http.PostForm("http://example.com/form",
url.Values{"key": {"Value"}, "id": {"123"}})
The client must close the response body when finished with it:
resp, err := http.Get("http://example.com/")
if err != nil {
// handle error
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
// ...
For control over HTTP client headers, redirect policy, and other
settings, create a Client:
client := &http.Client{
CheckRedirect: redirectPolicyFunc,
}
resp, err := client.Get("http://example.com")
// ...
req := http.NewRequest("GET", "http://example.com", nil)
req.Header.Add("If-None-Match", `W/"wyzzy"`)
resp, err := client.Do(req)
// ...
For control over proxies, TLS configuration, keep-alives,
compression, and other settings, create a Transport:
tr := &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: pool},
DisableCompression: true,
}
client := &http.Client{Transport: tr}
resp, err := client.Get("https://example.com")
Clients and Transports are safe for concurrent use by multiple
goroutines and for efficiency should only be created once and re-used.
ListenAndServe starts an HTTP server with a given address and handler.
The handler is usually nil, which means to use DefaultServeMux.
Handle and HandleFunc add handlers to DefaultServeMux:
http.Handle("/foo", fooHandler)
http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.RawPath))
})
log.Fatal(http.ListenAndServe(":8080", nil))
More control over the server's behavior is available by creating a
custom Server:
s := &http.Server{
Addr: ":8080",
Handler: myHandler,
ReadTimeout: 10e9,
WriteTimeout: 10e9,
MaxHeaderBytes: 1 << 20,
}
log.Fatal(s.ListenAndServe())
*/
package http
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
// HTTP Request reading and parsing. // HTTP Request reading and parsing.
// Package http implements parsing of HTTP requests, replies, and URLs and
// provides an extensible HTTP server and a basic HTTP client.
package http package http
import ( import (
......
...@@ -2,6 +2,11 @@ ...@@ -2,6 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// HTTP client implementation. See RFC 2616.
//
// This is the low-level Transport implementation of RoundTripper.
// The high-level interface is in client.go.
package http package http
import ( import (
...@@ -357,9 +362,11 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { ...@@ -357,9 +362,11 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
if err = conn.(*tls.Conn).Handshake(); err != nil { if err = conn.(*tls.Conn).Handshake(); err != nil {
return nil, err return nil, err
} }
if t.TLSClientConfig == nil || !t.TLSClientConfig.InsecureSkipVerify {
if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil { if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil {
return nil, err return nil, err
} }
}
pconn.conn = conn pconn.conn = conn
} }
......
...@@ -11,9 +11,6 @@ import ( ...@@ -11,9 +11,6 @@ import (
) )
func setKernelSpecificSockopt(s syscall.Handle, f int) { func setKernelSpecificSockopt(s syscall.Handle, f int) {
// Allow reuse of recently-used addresses and ports.
syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
// Allow broadcast. // Allow broadcast.
syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment