Commit 8039ca76 by Ian Lance Taylor

Update to current version of Go library.

From-SVN: r171427
parent 7114321e
94d654be2064 31d7feb9281b
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
"hash/crc32" "hash/crc32"
"encoding/binary" "encoding/binary"
"io" "io"
"io/ioutil"
"os" "os"
) )
...@@ -109,7 +110,7 @@ func (f *File) Open() (rc io.ReadCloser, err os.Error) { ...@@ -109,7 +110,7 @@ func (f *File) Open() (rc io.ReadCloser, err os.Error) {
r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size) r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size)
switch f.Method { switch f.Method {
case 0: // store (no compression) case 0: // store (no compression)
rc = nopCloser{r} rc = ioutil.NopCloser(r)
case 8: // DEFLATE case 8: // DEFLATE
rc = flate.NewReader(r) rc = flate.NewReader(r)
default: default:
...@@ -147,12 +148,6 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) { ...@@ -147,12 +148,6 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) {
func (r *checksumReader) Close() os.Error { return r.rc.Close() } func (r *checksumReader) Close() os.Error { return r.rc.Close() }
type nopCloser struct {
io.Reader
}
func (f nopCloser) Close() os.Error { return nil }
func readFileHeader(f *File, r io.Reader) (err os.Error) { func readFileHeader(f *File, r io.Reader) (err os.Error) {
defer func() { defer func() {
if rerr, ok := recover().(os.Error); ok { if rerr, ok := recover().(os.Error); ok {
......
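For reference, the local nopCloser type removed above is replaced by the equivalent helper in io/ioutil. A minimal usage sketch against the current standard library (io.ReadAll is the modern spelling; ioutil.NopCloser still exists, though newer code would call io.NopCloser):

package main

import (
	"fmt"
	"io"
	"io/ioutil"
	"strings"
)

func main() {
	// NopCloser wraps a plain Reader with a Close method that does nothing,
	// which is exactly what the removed local nopCloser type provided.
	var rc io.ReadCloser = ioutil.NopCloser(strings.NewReader("stored data"))
	b, _ := io.ReadAll(rc)
	fmt.Println(string(b), rc.Close()) // stored data <nil>
}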
...@@ -8,6 +8,7 @@ package big ...@@ -8,6 +8,7 @@ package big
import ( import (
"fmt" "fmt"
"os"
"rand" "rand"
) )
...@@ -393,62 +394,19 @@ func (z *Int) SetString(s string, base int) (*Int, bool) { ...@@ -393,62 +394,19 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
} }
// SetBytes interprets b as the bytes of a big-endian, unsigned integer and // SetBytes interprets buf as the bytes of a big-endian unsigned
// sets z to that value. // integer, sets z to that value, and returns z.
func (z *Int) SetBytes(b []byte) *Int { func (z *Int) SetBytes(buf []byte) *Int {
const s = _S z.abs = z.abs.setBytes(buf)
z.abs = z.abs.make((len(b) + s - 1) / s)
j := 0
for len(b) >= s {
var w Word
for i := s; i > 0; i-- {
w <<= 8
w |= Word(b[len(b)-i])
}
z.abs[j] = w
j++
b = b[0 : len(b)-s]
}
if len(b) > 0 {
var w Word
for i := len(b); i > 0; i-- {
w <<= 8
w |= Word(b[len(b)-i])
}
z.abs[j] = w
}
z.abs = z.abs.norm()
z.neg = false z.neg = false
return z return z
} }
// Bytes returns the absolute value of x as a big-endian byte array. // Bytes returns the absolute value of z as a big-endian byte slice.
func (z *Int) Bytes() []byte { func (z *Int) Bytes() []byte {
const s = _S buf := make([]byte, len(z.abs)*_S)
b := make([]byte, len(z.abs)*s) return buf[z.abs.bytes(buf):]
for i, w := range z.abs {
wordBytes := b[(len(z.abs)-i-1)*s : (len(z.abs)-i)*s]
for j := s - 1; j >= 0; j-- {
wordBytes[j] = byte(w)
w >>= 8
}
}
i := 0
for i < len(b) && b[i] == 0 {
i++
}
return b[i:]
} }
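As a quick illustration of the big-endian contract documented above, the public SetBytes/Bytes pair round-trips like this (a sketch against the current math/big API, which kept the same semantics):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// 0x01 0x02 0x03 read big-endian is 0x010203 = 66051.
	z := new(big.Int).SetBytes([]byte{0x01, 0x02, 0x03})
	fmt.Println(z)                 // 66051
	fmt.Printf("% x\n", z.Bytes()) // 01 02 03
}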
...@@ -739,3 +697,34 @@ func (z *Int) Not(x *Int) *Int { ...@@ -739,3 +697,34 @@ func (z *Int) Not(x *Int) *Int {
z.neg = true // z cannot be zero if x is positive z.neg = true // z cannot be zero if x is positive
return z return z
} }
// Gob codec version. Permits backward-compatible changes to the encoding.
const version byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (z *Int) GobEncode() ([]byte, os.Error) {
buf := make([]byte, len(z.abs)*_S+1) // extra byte for version and sign bit
i := z.abs.bytes(buf) - 1 // i >= 0
b := version << 1 // make space for sign bit
if z.neg {
b |= 1
}
buf[i] = b
return buf[i:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Int) GobDecode(buf []byte) os.Error {
if len(buf) == 0 {
return os.NewError("Int.GobDecode: no data")
}
b := buf[0]
if b>>1 != version {
return os.NewError(fmt.Sprintf("Int.GobDecode: encoding version %d not supported", b>>1))
}
z.neg = b&1 != 0
z.abs = z.abs.setBytes(buf[1:])
return nil
}
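For reference, the gob encoding added here is one header byte holding (version << 1 | sign) followed by the big-endian magnitude. A standalone sketch of the same layout (illustrative helper name, current math/big API; not the library code itself):

package main

import (
	"fmt"
	"math/big"
)

// encodeIntGobStyle mirrors the layout GobEncode writes above: a version/sign
// byte, then the absolute value in big-endian order.
func encodeIntGobStyle(x *big.Int) []byte {
	const version = 1
	b := byte(version << 1) // make space for the sign bit
	if x.Sign() < 0 {
		b |= 1
	}
	return append([]byte{b}, new(big.Int).Abs(x).Bytes()...)
}

func main() {
	fmt.Printf("% x\n", encodeIntGobStyle(big.NewInt(-42))) // 03 2a
	fmt.Printf("% x\n", encodeIntGobStyle(big.NewInt(42)))  // 02 2a
}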
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"gob"
"testing" "testing"
"testing/quick" "testing/quick"
) )
...@@ -1053,3 +1054,41 @@ func TestModInverse(t *testing.T) { ...@@ -1053,3 +1054,41 @@ func TestModInverse(t *testing.T) {
} }
} }
} }
var gobEncodingTests = []string{
"0",
"1",
"2",
"10",
"42",
"1234567890",
"298472983472983471903246121093472394872319615612417471234712061",
}
func TestGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
for i, test := range gobEncodingTests {
for j := 0; j < 2; j++ {
medium.Reset() // empty buffer for each test case (in case of failures)
stest := test
if j == 0 {
stest = "-" + test
}
var tx Int
tx.SetString(stest, 10)
if err := enc.Encode(&tx); err != nil {
t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
}
var rx Int
if err := dec.Decode(&rx); err != nil {
t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
}
if rx.Cmp(&tx) != 0 {
t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
}
}
}
}
...@@ -1065,3 +1065,50 @@ NextRandom: ...@@ -1065,3 +1065,50 @@ NextRandom:
return true return true
} }
// bytes writes the value of z into buf using big-endian encoding.
// len(buf) must be >= len(z)*_S. The value of z is encoded in the
// slice buf[i:]. The number i of unused bytes at the beginning of
// buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
i = len(buf)
for _, d := range z {
for j := 0; j < _S; j++ {
i--
buf[i] = byte(d)
d >>= 8
}
}
for i < len(buf) && buf[i] == 0 {
i++
}
return
}
// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z nat) setBytes(buf []byte) nat {
z = z.make((len(buf) + _S - 1) / _S)
k := 0
s := uint(0)
var d Word
for i := len(buf); i > 0; i-- {
d |= Word(buf[i-1]) << s
if s += 8; s == _S*8 {
z[k] = d
k++
s = 0
d = 0
}
}
if k < len(z) {
z[k] = d
}
return z.norm()
}
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package bufio package bufio_test
import ( import (
. "bufio"
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
...@@ -502,9 +503,8 @@ func TestWriteString(t *testing.T) { ...@@ -502,9 +503,8 @@ func TestWriteString(t *testing.T) {
b.WriteString("7890") // easy after flush b.WriteString("7890") // easy after flush
b.WriteString("abcdefghijklmnopqrstuvwxy") // hard b.WriteString("abcdefghijklmnopqrstuvwxy") // hard
b.WriteString("z") b.WriteString("z")
b.Flush() if err := b.Flush(); err != nil {
if b.err != nil { t.Error("WriteString", err)
t.Error("WriteString", b.err)
} }
s := "01234567890abcdefghijklmnopqrstuvwxyz" s := "01234567890abcdefghijklmnopqrstuvwxyz"
if string(buf.Bytes()) != s { if string(buf.Bytes()) != s {
......
...@@ -191,9 +191,16 @@ func testSync(t *testing.T, level int, input []byte, name string) { ...@@ -191,9 +191,16 @@ func testSync(t *testing.T, level int, input []byte, name string) {
t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo]) t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo])
return return
} }
if i == 0 && buf.buf.Len() != 0 { // This test originally checked that after reading
t.Errorf("testSync/%d (%d, %d, %s): extra data after %d", i, level, len(input), name, hi-lo) // the first half of the input, there was nothing left
} // in the read buffer (buf.buf.Len() != 0) but that is
// not necessarily the case: the write Flush may emit
// some extra framing bits that are not necessary
// to process to obtain the first half of the uncompressed
// data. The test ran correctly most of the time, because
// the background goroutine had usually read even
// those extra bits by now, but it's not a useful thing to
// check.
buf.WriteMode() buf.WriteMode()
} }
buf.ReadMode() buf.ReadMode()
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
...@@ -117,16 +118,34 @@ func (devNull) Write(p []byte) (int, os.Error) { ...@@ -117,16 +118,34 @@ func (devNull) Write(p []byte) (int, os.Error) {
return len(p), nil return len(p), nil
} }
func BenchmarkDecoder(b *testing.B) { func benchmarkDecoder(b *testing.B, n int) {
b.StopTimer() b.StopTimer()
b.SetBytes(int64(n))
buf0, _ := ioutil.ReadFile("../testdata/e.txt") buf0, _ := ioutil.ReadFile("../testdata/e.txt")
buf0 = buf0[:10000]
compressed := bytes.NewBuffer(nil) compressed := bytes.NewBuffer(nil)
w := NewWriter(compressed, LSB, 8) w := NewWriter(compressed, LSB, 8)
io.Copy(w, bytes.NewBuffer(buf0)) for i := 0; i < n; i += len(buf0) {
io.Copy(w, bytes.NewBuffer(buf0))
}
w.Close() w.Close()
buf1 := compressed.Bytes() buf1 := compressed.Bytes()
buf0, compressed, w = nil, nil, nil
runtime.GC()
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8)) io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8))
} }
} }
func BenchmarkDecoder1e4(b *testing.B) {
benchmarkDecoder(b, 1e4)
}
func BenchmarkDecoder1e5(b *testing.B) {
benchmarkDecoder(b, 1e5)
}
func BenchmarkDecoder1e6(b *testing.B) {
benchmarkDecoder(b, 1e6)
}
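The b.SetBytes(int64(n)) call added here is what makes the benchmark runner report throughput for these benchmarks rather than just ns/op. A minimal standalone example of the pattern (hypothetical benchmark, current testing API):

package lzwbench

import "testing"

// BenchmarkCopy1e6 reports MB/s because SetBytes records how many bytes each
// iteration of the timed loop processes.
func BenchmarkCopy1e6(b *testing.B) {
	b.StopTimer()
	b.SetBytes(1e6)
	src := make([]byte, 1e6)
	dst := make([]byte, 1e6)
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		copy(dst, src)
	}
}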
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"runtime"
"testing" "testing"
) )
...@@ -99,13 +100,33 @@ func TestWriter(t *testing.T) { ...@@ -99,13 +100,33 @@ func TestWriter(t *testing.T) {
} }
} }
func BenchmarkEncoder(b *testing.B) { func benchmarkEncoder(b *testing.B, n int) {
b.StopTimer() b.StopTimer()
buf, _ := ioutil.ReadFile("../testdata/e.txt") b.SetBytes(int64(n))
buf0, _ := ioutil.ReadFile("../testdata/e.txt")
buf0 = buf0[:10000]
buf1 := make([]byte, n)
for i := 0; i < n; i += len(buf0) {
copy(buf1[i:], buf0)
}
buf0 = nil
runtime.GC()
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
w := NewWriter(devNull{}, LSB, 8) w := NewWriter(devNull{}, LSB, 8)
w.Write(buf) w.Write(buf1)
w.Close() w.Close()
} }
} }
func BenchmarkEncoder1e4(b *testing.B) {
benchmarkEncoder(b, 1e4)
}
func BenchmarkEncoder1e5(b *testing.B) {
benchmarkEncoder(b, 1e5)
}
func BenchmarkEncoder1e6(b *testing.B) {
benchmarkEncoder(b, 1e6)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ecdsa implements the Elliptic Curve Digital Signature Algorithm, as
// defined in FIPS 186-3.
package ecdsa
// References:
// [NSA]: Suite B implementor's guide to FIPS 186-3,
// http://www.nsa.gov/ia/_files/ecdsa.pdf
// [SECG]: SECG, SEC1
// http://www.secg.org/download/aid-780/sec1-v2.pdf
import (
"big"
"crypto/elliptic"
"io"
"os"
)
// PublicKey represents an ECDSA public key.
type PublicKey struct {
*elliptic.Curve
X, Y *big.Int
}
// PrivateKey represents an ECDSA private key.
type PrivateKey struct {
PublicKey
D *big.Int
}
var one = new(big.Int).SetInt64(1)
// randFieldElement returns a random element of the field underlying the given
// curve using the procedure given in [NSA] A.2.1.
func randFieldElement(c *elliptic.Curve, rand io.Reader) (k *big.Int, err os.Error) {
b := make([]byte, c.BitSize/8+8)
_, err = rand.Read(b)
if err != nil {
return
}
k = new(big.Int).SetBytes(b)
n := new(big.Int).Sub(c.N, one)
k.Mod(k, n)
k.Add(k, one)
return
}
// GenerateKey generates a public&private key pair.
func GenerateKey(c *elliptic.Curve, rand io.Reader) (priv *PrivateKey, err os.Error) {
k, err := randFieldElement(c, rand)
if err != nil {
return
}
priv = new(PrivateKey)
priv.PublicKey.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
return
}
// hashToInt converts a hash value to an integer. There is some disagreement
// about how this is done. [NSA] suggests that this is done in the obvious
// manner, but [SECG] truncates the hash to the bit-length of the curve order
// first. We follow [SECG] because that's what OpenSSL does.
func hashToInt(hash []byte, c *elliptic.Curve) *big.Int {
orderBits := c.N.BitLen()
orderBytes := (orderBits + 7) / 8
if len(hash) > orderBytes {
hash = hash[:orderBytes]
}
ret := new(big.Int).SetBytes(hash)
excess := orderBytes*8 - orderBits
if excess > 0 {
ret.Rsh(ret, uint(excess))
}
return ret
}
// Sign signs an arbitrary length hash (which should be the result of hashing a
// larger message) using the private key, priv. It returns the signature as a
// pair of integers. The security of the private key depends on the entropy of
// rand.
func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err os.Error) {
// See [NSA] 3.4.1
c := priv.PublicKey.Curve
var k, kInv *big.Int
for {
for {
k, err = randFieldElement(c, rand)
if err != nil {
r = nil
return
}
kInv = new(big.Int).ModInverse(k, c.N)
r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
r.Mod(r, priv.Curve.N)
if r.Sign() != 0 {
break
}
}
e := hashToInt(hash, c)
s = new(big.Int).Mul(priv.D, r)
s.Add(s, e)
s.Mul(s, kInv)
s.Mod(s, priv.PublicKey.Curve.N)
if s.Sign() != 0 {
break
}
}
return
}
// Verify verifies the signature in r, s of hash using the public key, pub. It
// returns true iff the signature is valid.
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
// See [NSA] 3.4.2
c := pub.Curve
if r.Sign() == 0 || s.Sign() == 0 {
return false
}
if r.Cmp(c.N) >= 0 || s.Cmp(c.N) >= 0 {
return false
}
e := hashToInt(hash, c)
w := new(big.Int).ModInverse(s, c.N)
u1 := e.Mul(e, w)
u2 := w.Mul(r, w)
x1, y1 := c.ScalarBaseMult(u1.Bytes())
x2, y2 := c.ScalarMult(pub.X, pub.Y, u2.Bytes())
if x1.Cmp(x2) == 0 {
return false
}
x, _ := c.Add(x1, y1, x2, y2)
x.Mod(x, c.N)
return x.Cmp(r) == 0
}
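A usage sketch for the new package, written against the current crypto/ecdsa API, which kept these entry points but uses error and the elliptic.Curve interface instead of the os.Error and *elliptic.Curve shown above; the message and hash choice are illustrative:

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
)

func main() {
	// Generate a P-224 key, sign a message hash, and verify the (r, s) pair.
	priv, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
	if err != nil {
		panic(err)
	}
	hashed := sha256.Sum256([]byte("hello"))
	r, s, err := ecdsa.Sign(rand.Reader, priv, hashed[:])
	if err != nil {
		panic(err)
	}
	fmt.Println(ecdsa.Verify(&priv.PublicKey, hashed[:], r, s)) // true
}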
...@@ -24,6 +24,7 @@ import ( ...@@ -24,6 +24,7 @@ import (
// See http://www.hyperelliptic.org/EFD/g1p/auto-shortw.html // See http://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
type Curve struct { type Curve struct {
P *big.Int // the order of the underlying field P *big.Int // the order of the underlying field
N *big.Int // the order of the base point
B *big.Int // the constant of the curve equation B *big.Int // the constant of the curve equation
Gx, Gy *big.Int // (x,y) of the base point Gx, Gy *big.Int // (x,y) of the base point
BitSize int // the size of the underlying field BitSize int // the size of the underlying field
...@@ -315,6 +316,7 @@ func initP224() { ...@@ -315,6 +316,7 @@ func initP224() {
// See FIPS 186-3, section D.2.2 // See FIPS 186-3, section D.2.2
p224 = new(Curve) p224 = new(Curve)
p224.P, _ = new(big.Int).SetString("26959946667150639794667015087019630673557916260026308143510066298881", 10) p224.P, _ = new(big.Int).SetString("26959946667150639794667015087019630673557916260026308143510066298881", 10)
p224.N, _ = new(big.Int).SetString("26959946667150639794667015087019625940457807714424391721682722368061", 10)
p224.B, _ = new(big.Int).SetString("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4", 16) p224.B, _ = new(big.Int).SetString("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4", 16)
p224.Gx, _ = new(big.Int).SetString("b70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21", 16) p224.Gx, _ = new(big.Int).SetString("b70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21", 16)
p224.Gy, _ = new(big.Int).SetString("bd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34", 16) p224.Gy, _ = new(big.Int).SetString("bd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34", 16)
...@@ -325,6 +327,7 @@ func initP256() { ...@@ -325,6 +327,7 @@ func initP256() {
// See FIPS 186-3, section D.2.3 // See FIPS 186-3, section D.2.3
p256 = new(Curve) p256 = new(Curve)
p256.P, _ = new(big.Int).SetString("115792089210356248762697446949407573530086143415290314195533631308867097853951", 10) p256.P, _ = new(big.Int).SetString("115792089210356248762697446949407573530086143415290314195533631308867097853951", 10)
p256.N, _ = new(big.Int).SetString("115792089210356248762697446949407573529996955224135760342422259061068512044369", 10)
p256.B, _ = new(big.Int).SetString("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", 16) p256.B, _ = new(big.Int).SetString("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", 16)
p256.Gx, _ = new(big.Int).SetString("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", 16) p256.Gx, _ = new(big.Int).SetString("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", 16)
p256.Gy, _ = new(big.Int).SetString("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", 16) p256.Gy, _ = new(big.Int).SetString("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", 16)
...@@ -335,6 +338,7 @@ func initP384() { ...@@ -335,6 +338,7 @@ func initP384() {
// See FIPS 186-3, section D.2.4 // See FIPS 186-3, section D.2.4
p384 = new(Curve) p384 = new(Curve)
p384.P, _ = new(big.Int).SetString("39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319", 10) p384.P, _ = new(big.Int).SetString("39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319", 10)
p384.N, _ = new(big.Int).SetString("39402006196394479212279040100143613805079739270465446667946905279627659399113263569398956308152294913554433653942643", 10)
p384.B, _ = new(big.Int).SetString("b3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef", 16) p384.B, _ = new(big.Int).SetString("b3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef", 16)
p384.Gx, _ = new(big.Int).SetString("aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab7", 16) p384.Gx, _ = new(big.Int).SetString("aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab7", 16)
p384.Gy, _ = new(big.Int).SetString("3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f", 16) p384.Gy, _ = new(big.Int).SetString("3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f", 16)
...@@ -345,6 +349,7 @@ func initP521() { ...@@ -345,6 +349,7 @@ func initP521() {
// See FIPS 186-3, section D.2.5 // See FIPS 186-3, section D.2.5
p521 = new(Curve) p521 = new(Curve)
p521.P, _ = new(big.Int).SetString("6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151", 10) p521.P, _ = new(big.Int).SetString("6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151", 10)
p521.N, _ = new(big.Int).SetString("6864797660130609714981900799081393217269435300143305409394463459185543183397655394245057746333217197532963996371363321113864768612440380340372808892707005449", 10)
p521.B, _ = new(big.Int).SetString("051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00", 16) p521.B, _ = new(big.Int).SetString("051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00", 16)
p521.Gx, _ = new(big.Int).SetString("c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66", 16) p521.Gx, _ = new(big.Int).SetString("c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66", 16)
p521.Gy, _ = new(big.Int).SetString("11839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16650", 16) p521.Gy, _ = new(big.Int).SetString("11839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16650", 16)
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
package packet package packet
import ( import (
"big"
"crypto/aes" "crypto/aes"
"crypto/cast5" "crypto/cast5"
"crypto/cipher" "crypto/cipher"
...@@ -166,10 +167,10 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader, ...@@ -166,10 +167,10 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader,
return return
} }
// serialiseHeader writes an OpenPGP packet header to w. See RFC 4880, section // serializeHeader writes an OpenPGP packet header to w. See RFC 4880, section
// 4.2. // 4.2.
func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) { func serializeHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
var buf [5]byte var buf [6]byte
var n int var n int
buf[0] = 0x80 | 0x40 | byte(ptype) buf[0] = 0x80 | 0x40 | byte(ptype)
...@@ -178,16 +179,16 @@ func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) { ...@@ -178,16 +179,16 @@ func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
n = 2 n = 2
} else if length < 8384 { } else if length < 8384 {
length -= 192 length -= 192
buf[1] = byte(length >> 8) buf[1] = 192 + byte(length>>8)
buf[2] = byte(length) buf[2] = byte(length)
n = 3 n = 3
} else { } else {
buf[0] = 255 buf[1] = 255
buf[1] = byte(length >> 24) buf[2] = byte(length >> 24)
buf[2] = byte(length >> 16) buf[3] = byte(length >> 16)
buf[3] = byte(length >> 8) buf[4] = byte(length >> 8)
buf[4] = byte(length) buf[5] = byte(length)
n = 5 n = 6
} }
_, err = w.Write(buf[:n]) _, err = w.Write(buf[:n])
...@@ -371,7 +372,7 @@ func (cipher CipherFunction) new(key []byte) (block cipher.Block) { ...@@ -371,7 +372,7 @@ func (cipher CipherFunction) new(key []byte) (block cipher.Block) {
// readMPI reads a big integer from r. The bit length returned is the bit // readMPI reads a big integer from r. The bit length returned is the bit
// length that was specified in r. This is preserved so that the integer can be // length that was specified in r. This is preserved so that the integer can be
// reserialised exactly. // reserialized exactly.
func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) { func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
var buf [2]byte var buf [2]byte
_, err = readFull(r, buf[0:]) _, err = readFull(r, buf[0:])
...@@ -385,7 +386,7 @@ func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) { ...@@ -385,7 +386,7 @@ func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
return return
} }
// writeMPI serialises a big integer to r. // writeMPI serializes a big integer to w.
func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) { func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
_, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)}) _, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})
if err == nil { if err == nil {
...@@ -393,3 +394,8 @@ func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) { ...@@ -393,3 +394,8 @@ func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
} }
return return
} }
// writeBig serializes a *big.Int to w.
func writeBig(w io.Writer, i *big.Int) os.Error {
return writeMPI(w, uint16(i.BitLen()), i.Bytes())
}
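The MPI wire format handled by readMPI, writeMPI and the new writeBig is a two-byte big-endian bit count followed by the magnitude bytes (RFC 4880, section 3.2). A standalone sketch of the writer side (hypothetical helper name, bytes.Buffer for convenience):

package main

import (
	"bytes"
	"fmt"
	"math/big"
)

// writeMPISketch mirrors the layout written above: a 2-byte bit length,
// then the big-endian value bytes.
func writeMPISketch(w *bytes.Buffer, i *big.Int) {
	bits := uint16(i.BitLen())
	w.Write([]byte{byte(bits >> 8), byte(bits)})
	w.Write(i.Bytes())
}

func main() {
	var buf bytes.Buffer
	writeMPISketch(&buf, big.NewInt(511)) // 9 bits, magnitude 01 ff
	fmt.Printf("% x\n", buf.Bytes())      // 00 09 01 ff
}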
...@@ -190,3 +190,23 @@ func TestReadHeader(t *testing.T) { ...@@ -190,3 +190,23 @@ func TestReadHeader(t *testing.T) {
} }
} }
} }
func TestSerializeHeader(t *testing.T) {
tag := packetTypePublicKey
lengths := []int{0, 1, 2, 64, 192, 193, 8000, 8384, 8385, 10000}
for _, length := range lengths {
buf := bytes.NewBuffer(nil)
serializeHeader(buf, tag, length)
tag2, length2, _, err := readHeader(buf)
if err != nil {
t.Errorf("length %d, err: %s", length, err)
}
if tag2 != tag {
t.Errorf("length %d, tag incorrect (got %d, want %d)", length, tag2, tag)
}
if int(length2) != length {
t.Errorf("length %d, length incorrect (got %d)", length, length2)
}
}
}
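The serializeHeader fix exercised by this test follows the new-format length rules of RFC 4880, section 4.2.2: one octet below 192, two octets for 192..8383 encoded as ((len-192)>>8)+192 and (len-192)&0xff, and a 0xff marker plus four big-endian octets beyond that. A standalone sketch of just the length encoding (hypothetical helper; partial lengths not handled):

package main

import "fmt"

// encodeNewFormatLength returns the RFC 4880 new-format length octets
// for a body of n bytes.
func encodeNewFormatLength(n int) []byte {
	switch {
	case n < 192:
		return []byte{byte(n)}
	case n < 8384:
		n -= 192
		return []byte{192 + byte(n>>8), byte(n)}
	default:
		return []byte{255, byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
	}
}

func main() {
	fmt.Printf("% x\n", encodeNewFormatLength(100))  // 64
	fmt.Printf("% x\n", encodeNewFormatLength(1000)) // c3 28
	fmt.Printf("% x\n", encodeNewFormatLength(9000)) // ff 00 00 23 28
}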
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"big" "big"
"bytes" "bytes"
"crypto/cipher" "crypto/cipher"
"crypto/dsa"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"crypto/rsa" "crypto/rsa"
...@@ -134,7 +135,16 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error { ...@@ -134,7 +135,16 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error {
} }
func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) { func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
// TODO(agl): support DSA and ECDSA private keys. switch pk.PublicKey.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoRSAEncryptOnly:
return pk.parseRSAPrivateKey(data)
case PubKeyAlgoDSA:
return pk.parseDSAPrivateKey(data)
}
panic("impossible")
}
func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) {
rsaPub := pk.PublicKey.PublicKey.(*rsa.PublicKey) rsaPub := pk.PublicKey.PublicKey.(*rsa.PublicKey)
rsaPriv := new(rsa.PrivateKey) rsaPriv := new(rsa.PrivateKey)
rsaPriv.PublicKey = *rsaPub rsaPriv.PublicKey = *rsaPub
...@@ -162,3 +172,22 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) { ...@@ -162,3 +172,22 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
return nil return nil
} }
func (pk *PrivateKey) parseDSAPrivateKey(data []byte) (err os.Error) {
dsaPub := pk.PublicKey.PublicKey.(*dsa.PublicKey)
dsaPriv := new(dsa.PrivateKey)
dsaPriv.PublicKey = *dsaPub
buf := bytes.NewBuffer(data)
x, _, err := readMPI(buf)
if err != nil {
return
}
dsaPriv.X = new(big.Int).SetBytes(x)
pk.PrivateKey = dsaPriv
pk.Encrypted = false
pk.encryptedData = nil
return nil
}
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"encoding/binary" "encoding/binary"
"fmt"
"hash" "hash"
"io" "io"
"os" "os"
...@@ -178,12 +179,6 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E ...@@ -178,12 +179,6 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
return error.InvalidArgumentError("public key cannot generate signatures") return error.InvalidArgumentError("public key cannot generate signatures")
} }
rsaPublicKey, ok := pk.PublicKey.(*rsa.PublicKey)
if !ok {
// TODO(agl): support DSA and ECDSA keys.
return error.UnsupportedError("non-RSA public key")
}
signed.Write(sig.HashSuffix) signed.Write(sig.HashSuffix)
hashBytes := signed.Sum() hashBytes := signed.Sum()
...@@ -191,11 +186,28 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E ...@@ -191,11 +186,28 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
return error.SignatureError("hash tag doesn't match") return error.SignatureError("hash tag doesn't match")
} }
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.Signature) if pk.PubKeyAlgo != sig.PubKeyAlgo {
if err != nil { return error.InvalidArgumentError("public key and signature use different algorithms")
return error.SignatureError("RSA verification failure")
} }
return nil
switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature)
if err != nil {
return error.SignatureError("RSA verification failure")
}
return nil
case PubKeyAlgoDSA:
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
if !dsa.Verify(dsaPublicKey, hashBytes, sig.DSASigR, sig.DSASigS) {
return error.SignatureError("DSA verification failure")
}
return nil
default:
panic("shouldn't happen")
}
panic("unreachable")
} }
// VerifyKeySignature returns nil iff sig is a valid signature, make by this // VerifyKeySignature returns nil iff sig is a valid signature, make by this
...@@ -239,9 +251,21 @@ func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Er ...@@ -239,9 +251,21 @@ func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Er
return pk.VerifySignature(h, sig) return pk.VerifySignature(h, sig)
} }
// KeyIdString returns the public key's fingerprint in capital hex
// (e.g. "6C7EE1B8621CC013").
func (pk *PublicKey) KeyIdString() string {
return fmt.Sprintf("%X", pk.Fingerprint[12:20])
}
// KeyIdShortString returns the short form of public key's fingerprint
// in capital hex, as shown by gpg --list-keys (e.g. "621CC013").
func (pk *PublicKey) KeyIdShortString() string {
return fmt.Sprintf("%X", pk.Fingerprint[16:20])
}
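Both helpers are plain %X formatting over slices of the fingerprint; for example, with a hypothetical fingerprint ending in the bytes used in the doc comments:

package main

import "fmt"

func main() {
	// Last 8 bytes of a (hypothetical) 20-byte fingerprint.
	fp := [20]byte{12: 0x6c, 0x7e, 0xe1, 0xb8, 0x62, 0x1c, 0xc0, 0x13}
	fmt.Printf("%X\n", fp[12:20]) // 6C7EE1B8621CC013, the KeyIdString form
	fmt.Printf("%X\n", fp[16:20]) // 621CC013, the KeyIdShortString form
}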
// A parsedMPI is used to store the contents of a big integer, along with the // A parsedMPI is used to store the contents of a big integer, along with the
// bit length that was specified in the original input. This allows the MPI to // bit length that was specified in the original input. This allows the MPI to
// be reserialised exactly. // be reserialized exactly.
type parsedMPI struct { type parsedMPI struct {
bytes []byte bytes []byte
bitLength uint16 bitLength uint16
......
...@@ -16,9 +16,11 @@ var pubKeyTests = []struct { ...@@ -16,9 +16,11 @@ var pubKeyTests = []struct {
creationTime uint32 creationTime uint32
pubKeyAlgo PublicKeyAlgorithm pubKeyAlgo PublicKeyAlgorithm
keyId uint64 keyId uint64
keyIdString string
keyIdShort string
}{ }{
{rsaPkDataHex, rsaFingerprintHex, 0x4d3c5c10, PubKeyAlgoRSA, 0xa34d7e18c20c31bb}, {rsaPkDataHex, rsaFingerprintHex, 0x4d3c5c10, PubKeyAlgoRSA, 0xa34d7e18c20c31bb, "A34D7E18C20C31BB", "C20C31BB"},
{dsaPkDataHex, dsaFingerprintHex, 0x4d432f89, PubKeyAlgoDSA, 0x8e8fbe54062f19ed}, {dsaPkDataHex, dsaFingerprintHex, 0x4d432f89, PubKeyAlgoDSA, 0x8e8fbe54062f19ed, "8E8FBE54062F19ED", "062F19ED"},
} }
func TestPublicKeyRead(t *testing.T) { func TestPublicKeyRead(t *testing.T) {
...@@ -46,6 +48,12 @@ func TestPublicKeyRead(t *testing.T) { ...@@ -46,6 +48,12 @@ func TestPublicKeyRead(t *testing.T) {
if pk.KeyId != test.keyId { if pk.KeyId != test.keyId {
t.Errorf("#%d: bad keyid got:%x want:%x", i, pk.KeyId, test.keyId) t.Errorf("#%d: bad keyid got:%x want:%x", i, pk.KeyId, test.keyId)
} }
if g, e := pk.KeyIdString(), test.keyIdString; g != e {
t.Errorf("#%d: bad KeyIdString got:%q want:%q", i, g, e)
}
if g, e := pk.KeyIdShortString(), test.keyIdShort; g != e {
t.Errorf("#%d: bad KeyIdShortString got:%q want:%q", i, g, e)
}
} }
} }
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
package packet package packet
import ( import (
"big"
"crypto" "crypto"
"crypto/dsa"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"crypto/rand" "crypto/rand"
...@@ -29,7 +31,9 @@ type Signature struct { ...@@ -29,7 +31,9 @@ type Signature struct {
// of bad signed data. // of bad signed data.
HashTag [2]byte HashTag [2]byte
CreationTime uint32 // Unix epoch time CreationTime uint32 // Unix epoch time
Signature []byte
RSASignature []byte
DSASigR, DSASigS *big.Int
// The following are optional so are nil when not included in the // The following are optional so are nil when not included in the
// signature. // signature.
...@@ -66,7 +70,7 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) { ...@@ -66,7 +70,7 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
sig.SigType = SignatureType(buf[0]) sig.SigType = SignatureType(buf[0])
sig.PubKeyAlgo = PublicKeyAlgorithm(buf[1]) sig.PubKeyAlgo = PublicKeyAlgorithm(buf[1])
switch sig.PubKeyAlgo { switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
default: default:
err = error.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo))) err = error.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
return return
...@@ -122,8 +126,20 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) { ...@@ -122,8 +126,20 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
return return
} }
// We have already checked that the public key algorithm is RSA. switch sig.PubKeyAlgo {
sig.Signature, _, err = readMPI(r) case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sig.RSASignature, _, err = readMPI(r)
case PubKeyAlgoDSA:
var rBytes, sBytes []byte
rBytes, _, err = readMPI(r)
sig.DSASigR = new(big.Int).SetBytes(rBytes)
if err == nil {
sBytes, _, err = readMPI(r)
sig.DSASigS = new(big.Int).SetBytes(sBytes)
}
default:
panic("unreachable")
}
return return
} }
...@@ -316,8 +332,8 @@ func subpacketLengthLength(length int) int { ...@@ -316,8 +332,8 @@ func subpacketLengthLength(length int) int {
return 5 return 5
} }
// serialiseSubpacketLength marshals the given length into to. // serializeSubpacketLength marshals the given length into to.
func serialiseSubpacketLength(to []byte, length int) int { func serializeSubpacketLength(to []byte, length int) int {
if length < 192 { if length < 192 {
to[0] = byte(length) to[0] = byte(length)
return 1 return 1
...@@ -336,7 +352,7 @@ func serialiseSubpacketLength(to []byte, length int) int { ...@@ -336,7 +352,7 @@ func serialiseSubpacketLength(to []byte, length int) int {
return 5 return 5
} }
// subpacketsLength returns the serialised length, in bytes, of the given // subpacketsLength returns the serialized length, in bytes, of the given
// subpackets. // subpackets.
func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) { func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) {
for _, subpacket := range subpackets { for _, subpacket := range subpackets {
...@@ -349,11 +365,11 @@ func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) { ...@@ -349,11 +365,11 @@ func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) {
return return
} }
// serialiseSubpackets marshals the given subpackets into to. // serializeSubpackets marshals the given subpackets into to.
func serialiseSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) { func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
for _, subpacket := range subpackets { for _, subpacket := range subpackets {
if subpacket.hashed == hashed { if subpacket.hashed == hashed {
n := serialiseSubpacketLength(to, len(subpacket.contents)+1) n := serializeSubpacketLength(to, len(subpacket.contents)+1)
to[n] = byte(subpacket.subpacketType) to[n] = byte(subpacket.subpacketType)
to = to[1+n:] to = to[1+n:]
n = copy(to, subpacket.contents) n = copy(to, subpacket.contents)
...@@ -381,7 +397,7 @@ func (sig *Signature) buildHashSuffix() (err os.Error) { ...@@ -381,7 +397,7 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
} }
sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8) sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
sig.HashSuffix[5] = byte(hashedSubpacketsLen) sig.HashSuffix[5] = byte(hashedSubpacketsLen)
serialiseSubpackets(sig.HashSuffix[6:l], sig.outSubpackets, true) serializeSubpackets(sig.HashSuffix[6:l], sig.outSubpackets, true)
trailer := sig.HashSuffix[l:] trailer := sig.HashSuffix[l:]
trailer[0] = 4 trailer[0] = 4
trailer[1] = 0xff trailer[1] = 0xff
...@@ -392,32 +408,66 @@ func (sig *Signature) buildHashSuffix() (err os.Error) { ...@@ -392,32 +408,66 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
return return
} }
// SignRSA signs a message with an RSA private key. The hash, h, must contain func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err os.Error) {
// the hash of message to be signed and will be mutated by this function.
func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
err = sig.buildHashSuffix() err = sig.buildHashSuffix()
if err != nil { if err != nil {
return return
} }
h.Write(sig.HashSuffix) h.Write(sig.HashSuffix)
digest := h.Sum() digest = h.Sum()
copy(sig.HashTag[:], digest) copy(sig.HashTag[:], digest)
sig.Signature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
return return
} }
// Serialize marshals sig to w. SignRSA must have been called first. // SignRSA signs a message with an RSA private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
digest, err := sig.signPrepareHash(h)
if err != nil {
return
}
sig.RSASignature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
return
}
// SignDSA signs a message with a DSA private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignDSA(h hash.Hash, priv *dsa.PrivateKey) (err os.Error) {
digest, err := sig.signPrepareHash(h)
if err != nil {
return
}
sig.DSASigR, sig.DSASigS, err = dsa.Sign(rand.Reader, priv, digest)
return
}
// Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
func (sig *Signature) Serialize(w io.Writer) (err os.Error) { func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
if sig.Signature == nil { if sig.RSASignature == nil && sig.DSASigR == nil {
return error.InvalidArgumentError("Signature: need to call SignRSA before Serialize") return error.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
}
sigLength := 0
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sigLength = len(sig.RSASignature)
case PubKeyAlgoDSA:
sigLength = 2 /* MPI length */
sigLength += (sig.DSASigR.BitLen() + 7) / 8
sigLength += 2 /* MPI length */
sigLength += (sig.DSASigS.BitLen() + 7) / 8
default:
panic("impossible")
} }
unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false) unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false)
length := len(sig.HashSuffix) - 6 /* trailer not included */ + length := len(sig.HashSuffix) - 6 /* trailer not included */ +
2 /* length of unhashed subpackets */ + unhashedSubpacketsLen + 2 /* length of unhashed subpackets */ + unhashedSubpacketsLen +
2 /* hash tag */ + 2 /* length of signature MPI */ + len(sig.Signature) 2 /* hash tag */ + 2 /* length of signature MPI */ + sigLength
err = serialiseHeader(w, packetTypeSignature, length) err = serializeHeader(w, packetTypeSignature, length)
if err != nil { if err != nil {
return return
} }
...@@ -430,7 +480,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) { ...@@ -430,7 +480,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
unhashedSubpackets := make([]byte, 2+unhashedSubpacketsLen) unhashedSubpackets := make([]byte, 2+unhashedSubpacketsLen)
unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 8) unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 8)
unhashedSubpackets[1] = byte(unhashedSubpacketsLen) unhashedSubpackets[1] = byte(unhashedSubpacketsLen)
serialiseSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false) serializeSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false)
_, err = w.Write(unhashedSubpackets) _, err = w.Write(unhashedSubpackets)
if err != nil { if err != nil {
...@@ -440,7 +490,19 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) { ...@@ -440,7 +490,19 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
if err != nil { if err != nil {
return return
} }
return writeMPI(w, 8*uint16(len(sig.Signature)), sig.Signature)
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
err = writeMPI(w, 8*uint16(len(sig.RSASignature)), sig.RSASignature)
case PubKeyAlgoDSA:
err = writeBig(w, sig.DSASigR)
if err == nil {
err = writeBig(w, sig.DSASigS)
}
default:
panic("impossible")
}
return
} }
// outputSubpacket represents a subpacket to be marshaled. // outputSubpacket represents a subpacket to be marshaled.
......
...@@ -6,6 +6,7 @@ package openpgp ...@@ -6,6 +6,7 @@ package openpgp
import ( import (
"crypto" "crypto"
"crypto/dsa"
"crypto/openpgp/armor" "crypto/openpgp/armor"
"crypto/openpgp/error" "crypto/openpgp/error"
"crypto/openpgp/packet" "crypto/openpgp/packet"
...@@ -39,7 +40,7 @@ func DetachSignText(w io.Writer, signer *Entity, message io.Reader) os.Error { ...@@ -39,7 +40,7 @@ func DetachSignText(w io.Writer, signer *Entity, message io.Reader) os.Error {
// ArmoredDetachSignText signs message (after canonicalising the line endings) // ArmoredDetachSignText signs message (after canonicalising the line endings)
// with the private key from signer (which must already have been decrypted) // with the private key from signer (which must already have been decrypted)
// and writes an armored signature to w. // and writes an armored signature to w.
func SignTextDetachedArmored(w io.Writer, signer *Entity, message io.Reader) os.Error { func ArmoredDetachSignText(w io.Writer, signer *Entity, message io.Reader) os.Error {
return armoredDetachSign(w, signer, message, packet.SigTypeText) return armoredDetachSign(w, signer, message, packet.SigTypeText)
} }
...@@ -80,6 +81,9 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S ...@@ -80,6 +81,9 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSASignOnly: case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSASignOnly:
priv := signer.PrivateKey.PrivateKey.(*rsa.PrivateKey) priv := signer.PrivateKey.PrivateKey.(*rsa.PrivateKey)
err = sig.SignRSA(h, priv) err = sig.SignRSA(h, priv)
case packet.PubKeyAlgoDSA:
priv := signer.PrivateKey.PrivateKey.(*dsa.PrivateKey)
err = sig.SignDSA(h, priv)
default: default:
err = error.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo))) err = error.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
} }
......
...@@ -18,7 +18,7 @@ func TestSignDetached(t *testing.T) { ...@@ -18,7 +18,7 @@ func TestSignDetached(t *testing.T) {
t.Error(err) t.Error(err)
} }
testDetachedSignature(t, kring, out, signedInput, "check") testDetachedSignature(t, kring, out, signedInput, "check", testKey1KeyId)
} }
func TestSignTextDetached(t *testing.T) { func TestSignTextDetached(t *testing.T) {
...@@ -30,5 +30,17 @@ func TestSignTextDetached(t *testing.T) { ...@@ -30,5 +30,17 @@ func TestSignTextDetached(t *testing.T) {
t.Error(err) t.Error(err)
} }
testDetachedSignature(t, kring, out, signedInput, "check") testDetachedSignature(t, kring, out, signedInput, "check", testKey1KeyId)
}
func TestSignDetachedDSA(t *testing.T) {
kring, _ := ReadKeyRing(readerFromHex(dsaTestKeyPrivateHex))
out := bytes.NewBuffer(nil)
message := bytes.NewBufferString(signedInput)
err := DetachSign(out, kring[0], message)
if err != nil {
t.Error(err)
}
testDetachedSignature(t, kring, out, signedInput, "check", testKey3KeyId)
} }
...@@ -7,6 +7,7 @@ package tls ...@@ -7,6 +7,7 @@ package tls
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/x509"
"io" "io"
"io/ioutil" "io/ioutil"
"sync" "sync"
...@@ -95,6 +96,9 @@ type ConnectionState struct { ...@@ -95,6 +96,9 @@ type ConnectionState struct {
HandshakeComplete bool HandshakeComplete bool
CipherSuite uint16 CipherSuite uint16
NegotiatedProtocol string NegotiatedProtocol string
// the certificate chain that was presented by the other side
PeerCertificates []*x509.Certificate
} }
// A Config structure is used to configure a TLS client or server. After one // A Config structure is used to configure a TLS client or server. After one
......
...@@ -762,6 +762,7 @@ func (c *Conn) ConnectionState() ConnectionState { ...@@ -762,6 +762,7 @@ func (c *Conn) ConnectionState() ConnectionState {
if c.handshakeComplete { if c.handshakeComplete {
state.NegotiatedProtocol = c.clientProtocol state.NegotiatedProtocol = c.clientProtocol
state.CipherSuite = c.cipherSuite state.CipherSuite = c.cipherSuite
state.PeerCertificates = c.peerCertificates
} }
return state return state
...@@ -776,15 +777,6 @@ func (c *Conn) OCSPResponse() []byte { ...@@ -776,15 +777,6 @@ func (c *Conn) OCSPResponse() []byte {
return c.ocspResponse return c.ocspResponse
} }
// PeerCertificates returns the certificate chain that was presented by the
// other side.
func (c *Conn) PeerCertificates() []*x509.Certificate {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.peerCertificates
}
// VerifyHostname checks that the peer certificate chain is valid for // VerifyHostname checks that the peer certificate chain is valid for
// connecting to host. If so, it returns nil; if not, it returns an os.Error // connecting to host. If so, it returns nil; if not, it returns an os.Error
// describing the problem. // describing the problem.
......
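With this change callers reach the peer's certificates through ConnectionState rather than a method on Conn. A sketch against the current crypto/tls API (the address is a placeholder):

package main

import (
	"crypto/tls"
	"fmt"
)

func main() {
	conn, err := tls.Dial("tcp", "example.com:443", &tls.Config{})
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	// The certificate chain presented by the other side, leaf first.
	for _, cert := range conn.ConnectionState().PeerCertificates {
		fmt.Println(cert.Subject.CommonName)
	}
}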
...@@ -25,7 +25,7 @@ func main() { ...@@ -25,7 +25,7 @@ func main() {
priv, err := rsa.GenerateKey(rand.Reader, 1024) priv, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil { if err != nil {
log.Exitf("failed to generate private key: %s", err) log.Fatalf("failed to generate private key: %s", err)
return return
} }
...@@ -46,13 +46,13 @@ func main() { ...@@ -46,13 +46,13 @@ func main() {
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil { if err != nil {
log.Exitf("Failed to create certificate: %s", err) log.Fatalf("Failed to create certificate: %s", err)
return return
} }
certOut, err := os.Open("cert.pem", os.O_WRONLY|os.O_CREAT, 0644) certOut, err := os.Open("cert.pem", os.O_WRONLY|os.O_CREAT, 0644)
if err != nil { if err != nil {
log.Exitf("failed to open cert.pem for writing: %s", err) log.Fatalf("failed to open cert.pem for writing: %s", err)
return return
} }
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
......
...@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) { ...@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
return nil, os.NewError("debug/proc not implemented on OS X") return nil, os.NewError("debug/proc not implemented on OS X")
} }
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) { func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
return Attach(0) return Attach(0)
} }
...@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) { ...@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
return nil, os.NewError("debug/proc not implemented on FreeBSD") return nil, os.NewError("debug/proc not implemented on FreeBSD")
} }
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) { func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
return Attach(0) return Attach(0)
} }
...@@ -1279,25 +1279,31 @@ func Attach(pid int) (Process, os.Error) { ...@@ -1279,25 +1279,31 @@ func Attach(pid int) (Process, os.Error) {
return p, nil return p, nil
} }
// ForkExec forks the current process and execs argv0, stopping the // StartProcess forks the current process and execs argv0, stopping the
// new process after the exec syscall. See os.ForkExec for additional // new process after the exec syscall. See os.StartProcess for additional
// details. // details.
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) { func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
sysattr := &syscall.ProcAttr{
Dir: attr.Dir,
Env: attr.Env,
Ptrace: true,
}
p := newProcess(-1) p := newProcess(-1)
// Create array of integer (system) fds. // Create array of integer (system) fds.
intfd := make([]int, len(fd)) intfd := make([]int, len(attr.Files))
for i, f := range fd { for i, f := range attr.Files {
if f == nil { if f == nil {
intfd[i] = -1 intfd[i] = -1
} else { } else {
intfd[i] = f.Fd() intfd[i] = f.Fd()
} }
} }
sysattr.Files = intfd
// Fork from the monitor thread so we get the right tracer pid. // Fork from the monitor thread so we get the right tracer pid.
err := p.do(func() os.Error { err := p.do(func() os.Error {
pid, errno := syscall.PtraceForkExec(argv0, argv, envv, dir, intfd) pid, _, errno := syscall.StartProcess(argv0, argv, sysattr)
if errno != 0 { if errno != 0 {
return &os.PathError{"fork/exec", argv0, os.Errno(errno)} return &os.PathError{"fork/exec", argv0, os.Errno(errno)}
} }
......
...@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) { ...@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
return nil, os.NewError("debug/proc not implemented on windows") return nil, os.NewError("debug/proc not implemented on windows")
} }
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) { func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
return Attach(0) return Attach(0)
} }
...@@ -75,17 +75,19 @@ func modeToFiles(mode, fd int) (*os.File, *os.File, os.Error) { ...@@ -75,17 +75,19 @@ func modeToFiles(mode, fd int) (*os.File, *os.File, os.Error) {
// Run starts the named binary running with // Run starts the named binary running with
// arguments argv and environment envv. // arguments argv and environment envv.
// If the dir argument is not empty, the child changes
// into the directory before executing the binary.
// It returns a pointer to a new Cmd representing // It returns a pointer to a new Cmd representing
// the command or an error. // the command or an error.
// //
// The parameters stdin, stdout, and stderr // The arguments stdin, stdout, and stderr
// specify how to handle standard input, output, and error. // specify how to handle standard input, output, and error.
// The choices are DevNull (connect to /dev/null), // The choices are DevNull (connect to /dev/null),
// PassThrough (connect to the current process's standard stream), // PassThrough (connect to the current process's standard stream),
// Pipe (connect to an operating system pipe), and // Pipe (connect to an operating system pipe), and
// MergeWithStdout (only for standard error; use the same // MergeWithStdout (only for standard error; use the same
// file descriptor as was used for standard output). // file descriptor as was used for standard output).
// If a parameter is Pipe, then the corresponding field (Stdin, Stdout, Stderr) // If an argument is Pipe, then the corresponding field (Stdin, Stdout, Stderr)
// of the returned Cmd is the other end of the pipe. // of the returned Cmd is the other end of the pipe.
// Otherwise the field in Cmd is nil. // Otherwise the field in Cmd is nil.
func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int) (c *Cmd, err os.Error) { func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int) (c *Cmd, err os.Error) {
...@@ -105,7 +107,7 @@ func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int ...@@ -105,7 +107,7 @@ func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int
} }
// Run command. // Run command.
c.Process, err = os.StartProcess(name, argv, envv, dir, fd[0:]) c.Process, err = os.StartProcess(name, argv, &os.ProcAttr{Dir: dir, Files: fd[:], Env: envv})
if err != nil { if err != nil {
goto Error goto Error
} }
......
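The debug/proc and exec changes track the os package's move from ForkExec to StartProcess, which bundles the directory, environment and file descriptors into a ProcAttr. A minimal sketch of the new call shape (current os API; the binary path is a placeholder):

package main

import (
	"fmt"
	"os"
)

func main() {
	// StartProcess takes what ForkExec used to take as separate arguments
	// (dir, env, fds) bundled into a ProcAttr.
	p, err := os.StartProcess("/bin/echo", []string{"echo", "hello"}, &os.ProcAttr{
		Dir:   "/",
		Env:   os.Environ(),
		Files: []*os.File{os.Stdin, os.Stdout, os.Stderr},
	})
	if err != nil {
		panic(err)
	}
	state, err := p.Wait()
	if err != nil {
		panic(err)
	}
	fmt.Println(state.Success())
}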
...@@ -118,3 +118,55 @@ func TestAddEnvVar(t *testing.T) { ...@@ -118,3 +118,55 @@ func TestAddEnvVar(t *testing.T) {
t.Fatal("close:", err) t.Fatal("close:", err)
} }
} }
var tryargs = []string{
`2`,
`2 `,
"2 \t",
`2" "`,
`2 ab `,
`2 "ab" `,
`2 \ `,
`2 \\ `,
`2 \" `,
`2 \`,
`2\`,
`2"`,
`2\"`,
`2 "`,
`2 \"`,
``,
`2 ^ `,
`2 \^`,
}
func TestArgs(t *testing.T) {
for _, a := range tryargs {
argv := []string{
"awk",
`BEGIN{printf("%s|%s|%s",ARGV[1],ARGV[2],ARGV[3])}`,
"/dev/null",
a,
"EOF",
}
exe, err := LookPath(argv[0])
if err != nil {
t.Fatal("run:", err)
}
cmd, err := Run(exe, argv, nil, "", DevNull, Pipe, DevNull)
if err != nil {
t.Fatal("run:", err)
}
buf, err := ioutil.ReadAll(cmd.Stdout)
if err != nil {
t.Fatal("read:", err)
}
expect := "/dev/null|" + a + "|EOF"
if string(buf) != expect {
t.Errorf("read: got %q expect %q", buf, expect)
}
if err = cmd.Close(); err != nil {
t.Fatal("close:", err)
}
}
}
...@@ -287,9 +287,6 @@ func (a *stmtCompiler) compile(s ast.Stmt) { ...@@ -287,9 +287,6 @@ func (a *stmtCompiler) compile(s ast.Stmt) {
case *ast.SwitchStmt: case *ast.SwitchStmt:
a.compileSwitchStmt(s) a.compileSwitchStmt(s)
case *ast.TypeCaseClause:
notimpl = true
case *ast.TypeSwitchStmt: case *ast.TypeSwitchStmt:
notimpl = true notimpl = true
...@@ -1012,13 +1009,13 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) { ...@@ -1012,13 +1009,13 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
a.diagAt(clause.Pos(), "switch statement must contain case clauses") a.diagAt(clause.Pos(), "switch statement must contain case clauses")
continue continue
} }
if clause.Values == nil { if clause.List == nil {
if hasDefault { if hasDefault {
a.diagAt(clause.Pos(), "switch statement contains more than one default case") a.diagAt(clause.Pos(), "switch statement contains more than one default case")
} }
hasDefault = true hasDefault = true
} else { } else {
ncases += len(clause.Values) ncases += len(clause.List)
} }
} }
...@@ -1030,7 +1027,7 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) { ...@@ -1030,7 +1027,7 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
if !ok { if !ok {
continue continue
} }
for _, v := range clause.Values { for _, v := range clause.List {
e := condbc.compileExpr(condbc.block, false, v) e := condbc.compileExpr(condbc.block, false, v)
switch { switch {
case e == nil: case e == nil:
...@@ -1077,8 +1074,8 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) { ...@@ -1077,8 +1074,8 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
// Save jump PC's // Save jump PC's
pc := a.nextPC() pc := a.nextPC()
if clause.Values != nil { if clause.List != nil {
for _ = range clause.Values { for _ = range clause.List {
casePCs[i] = &pc casePCs[i] = &pc
i++ i++
} }
......
...@@ -27,7 +27,7 @@ var stmtTests = []test{ ...@@ -27,7 +27,7 @@ var stmtTests = []test{
CErr("i, u := 1, 2", atLeastOneDecl), CErr("i, u := 1, 2", atLeastOneDecl),
Val2("i, x := 2, f", "i", 2, "x", 1.0), Val2("i, x := 2, f", "i", 2, "x", 1.0),
// Various errors // Various errors
CErr("1 := 2", "left side of := must be a name"), CErr("1 := 2", "expected identifier"),
CErr("c, a := 1, 1", "cannot assign"), CErr("c, a := 1, 1", "cannot assign"),
// Unpacking // Unpacking
Val2("x, y := oneTwo()", "x", 1, "y", 2), Val2("x, y := oneTwo()", "x", 1, "y", 2),
......
...@@ -160,7 +160,7 @@ func cmdLoad(args []byte) os.Error { ...@@ -160,7 +160,7 @@ func cmdLoad(args []byte) os.Error {
} else { } else {
fname = parts[0] fname = parts[0]
} }
tproc, err = proc.ForkExec(fname, parts, os.Environ(), "", []*os.File{os.Stdin, os.Stdout, os.Stderr}) tproc, err = proc.StartProcess(fname, parts, &os.ProcAttr{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}})
if err != nil { if err != nil {
return err return err
} }
......
...@@ -269,7 +269,7 @@ func Iter() <-chan KeyValue { ...@@ -269,7 +269,7 @@ func Iter() <-chan KeyValue {
} }
func expvarHandler(w http.ResponseWriter, r *http.Request) { func expvarHandler(w http.ResponseWriter, r *http.Request) {
w.SetHeader("content-type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprintf(w, "{\n") fmt.Fprintf(w, "{\n")
first := true first := true
for name, value := range vars { for name, value := range vars {
......
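This tracks the http API change from ResponseWriter.SetHeader to setting entries on the Header() map before the body is written. A sketch of the new idiom (current net/http; handler name and port are illustrative):

package main

import (
	"fmt"
	"net/http"
)

func varsHandler(w http.ResponseWriter, r *http.Request) {
	// Headers now go through the Header() map and must be set before
	// the first write of the body.
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	fmt.Fprint(w, "{}\n")
}

func main() {
	http.HandleFunc("/debug/vars", varsHandler)
	http.ListenAndServe(":8080", nil)
}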
...@@ -56,7 +56,7 @@ ...@@ -56,7 +56,7 @@
flag.Bool(...) // global options flag.Bool(...) // global options
flag.Parse() // parse leading command flag.Parse() // parse leading command
subcmd := flag.Args(0) subcmd := flag.Arg[0]
switch subcmd { switch subcmd {
// add per-subcommand options // add per-subcommand options
} }
...@@ -68,6 +68,7 @@ package flag ...@@ -68,6 +68,7 @@ package flag
import ( import (
"fmt" "fmt"
"os" "os"
"sort"
"strconv" "strconv"
) )
...@@ -205,16 +206,34 @@ type allFlags struct { ...@@ -205,16 +206,34 @@ type allFlags struct {
var flags *allFlags var flags *allFlags
// VisitAll visits the flags, calling fn for each. It visits all flags, even those not set. // sortFlags returns the flags as a slice in lexicographical sorted order.
func sortFlags(flags map[string]*Flag) []*Flag {
list := make(sort.StringArray, len(flags))
i := 0
for _, f := range flags {
list[i] = f.Name
i++
}
list.Sort()
result := make([]*Flag, len(list))
for i, name := range list {
result[i] = flags[name]
}
return result
}
// VisitAll visits the flags in lexicographical order, calling fn for each.
// It visits all flags, even those not set.
func VisitAll(fn func(*Flag)) { func VisitAll(fn func(*Flag)) {
for _, f := range flags.formal { for _, f := range sortFlags(flags.formal) {
fn(f) fn(f)
} }
} }
// Visit visits the flags, calling fn for each. It visits only those flags that have been set. // Visit visits the flags in lexicographical order, calling fn for each.
// It visits only those flags that have been set.
func Visit(fn func(*Flag)) { func Visit(fn func(*Flag)) {
for _, f := range flags.actual { for _, f := range sortFlags(flags.actual) {
fn(f) fn(f)
} }
} }
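A minimal usage sketch of the sorted traversal added above (not part of the patch; the flag names "zebra" and "apple" are invented for illustration): with sortFlags in place, VisitAll reports flags alphabetically regardless of definition order.

package main

import (
	"flag"
	"fmt"
)

var (
	zebra = flag.Bool("zebra", false, "defined first, printed last")
	apple = flag.Int("apple", 1, "defined last, printed first")
)

func main() {
	flag.Parse()
	// sortFlags makes VisitAll walk the flags in lexicographical
	// order, so this prints "apple" before "zebra".
	flag.VisitAll(func(f *flag.Flag) {
		fmt.Println(f.Name, "=", f.Value)
	})
}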
...@@ -260,7 +279,9 @@ var Usage = func() { ...@@ -260,7 +279,9 @@ var Usage = func() {
var panicOnError = false var panicOnError = false
func fail() { // failf prints to standard error a formatted error and Usage, and then exits the program.
func failf(format string, a ...interface{}) {
fmt.Fprintf(os.Stderr, format, a...)
Usage() Usage()
if panicOnError { if panicOnError {
panic("flag parse error") panic("flag parse error")
...@@ -268,6 +289,7 @@ func fail() { ...@@ -268,6 +289,7 @@ func fail() {
os.Exit(2) os.Exit(2)
} }
// NFlag returns the number of flags that have been set.
func NFlag() int { return len(flags.actual) } func NFlag() int { return len(flags.actual) }
// Arg returns the i'th command-line argument. Arg(0) is the first remaining argument // Arg returns the i'th command-line argument. Arg(0) is the first remaining argument
...@@ -415,8 +437,7 @@ func (f *allFlags) parseOne() (ok bool) { ...@@ -415,8 +437,7 @@ func (f *allFlags) parseOne() (ok bool) {
} }
name := s[num_minuses:] name := s[num_minuses:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' { if len(name) == 0 || name[0] == '-' || name[0] == '=' {
fmt.Fprintln(os.Stderr, "bad flag syntax:", s) failf("bad flag syntax: %s\n", s)
fail()
} }
// it's a flag. does it have an argument? // it's a flag. does it have an argument?
...@@ -434,14 +455,12 @@ func (f *allFlags) parseOne() (ok bool) { ...@@ -434,14 +455,12 @@ func (f *allFlags) parseOne() (ok bool) {
m := flags.formal m := flags.formal
flag, alreadythere := m[name] // BUG flag, alreadythere := m[name] // BUG
if !alreadythere { if !alreadythere {
fmt.Fprintf(os.Stderr, "flag provided but not defined: -%s\n", name) failf("flag provided but not defined: -%s\n", name)
fail()
} }
if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
if has_value { if has_value {
if !fv.Set(value) { if !fv.Set(value) {
fmt.Fprintf(os.Stderr, "invalid boolean value %q for flag: -%s\n", value, name) failf("invalid boolean value %q for flag: -%s\n", value, name)
fail()
} }
} else { } else {
fv.Set("true") fv.Set("true")
...@@ -454,13 +473,11 @@ func (f *allFlags) parseOne() (ok bool) { ...@@ -454,13 +473,11 @@ func (f *allFlags) parseOne() (ok bool) {
value, f.args = f.args[0], f.args[1:] value, f.args = f.args[0], f.args[1:]
} }
if !has_value { if !has_value {
fmt.Fprintf(os.Stderr, "flag needs an argument: -%s\n", name) failf("flag needs an argument: -%s\n", name)
fail()
} }
ok = flag.Value.Set(value) ok = flag.Value.Set(value)
if !ok { if !ok {
fmt.Fprintf(os.Stderr, "invalid value %q for flag: -%s\n", value, name) failf("invalid value %q for flag: -%s\n", value, name)
fail()
} }
} }
flags.actual[name] = flag flags.actual[name] = flag
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
. "flag" . "flag"
"fmt" "fmt"
"os" "os"
"sort"
"testing" "testing"
) )
...@@ -77,6 +78,12 @@ func TestEverything(t *testing.T) { ...@@ -77,6 +78,12 @@ func TestEverything(t *testing.T) {
t.Log(k, *v) t.Log(k, *v)
} }
} }
// Now test they're visited in sort order.
var flagNames []string
Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) })
if !sort.StringsAreSorted(flagNames) {
t.Errorf("flag names not sorted: %v", flagNames)
}
} }
func TestUsage(t *testing.T) { func TestUsage(t *testing.T) {
......
...@@ -107,7 +107,7 @@ func (f *fmt) writePadding(n int, padding []byte) { ...@@ -107,7 +107,7 @@ func (f *fmt) writePadding(n int, padding []byte) {
} }
// Append b to f.buf, padded on left (w > 0) or right (w < 0 or f.minus) // Append b to f.buf, padded on left (w > 0) or right (w < 0 or f.minus)
// clear flags aftewards. // clear flags afterwards.
func (f *fmt) pad(b []byte) { func (f *fmt) pad(b []byte) {
var padding []byte var padding []byte
var left, right int var left, right int
...@@ -124,7 +124,7 @@ func (f *fmt) pad(b []byte) { ...@@ -124,7 +124,7 @@ func (f *fmt) pad(b []byte) {
} }
// append s to buf, padded on left (w > 0) or right (w < 0 or f.minus). // append s to buf, padded on left (w > 0) or right (w < 0 or f.minus).
// clear flags aftewards. // clear flags afterwards.
func (f *fmt) padString(s string) { func (f *fmt) padString(s string) {
var padding []byte var padding []byte
var left, right int var left, right int
......
...@@ -35,10 +35,15 @@ type ScanState interface { ...@@ -35,10 +35,15 @@ type ScanState interface {
ReadRune() (rune int, size int, err os.Error) ReadRune() (rune int, size int, err os.Error)
// UnreadRune causes the next call to ReadRune to return the same rune. // UnreadRune causes the next call to ReadRune to return the same rune.
UnreadRune() os.Error UnreadRune() os.Error
// Token returns the next space-delimited token from the input. If // Token skips space in the input if skipSpace is true, then returns the
// a width has been specified, the returned token will be no longer // run of Unicode code points c satisfying f(c). If f is nil,
// than the width. // !unicode.IsSpace(c) is used; that is, the token will hold non-space
Token() (token string, err os.Error) // characters. Newlines are treated as space unless the scan operation
// is Scanln, Fscanln or Sscanln, in which case a newline is treated as
// EOF. The returned slice points to shared data that may be overwritten
// by the next call to Token, a call to a Scan function using the ScanState
// as input, or when the calling Scan method returns.
Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error)
// Width returns the value of the width option and whether it has been set. // Width returns the value of the width option and whether it has been set.
// The unit is Unicode code points. // The unit is Unicode code points.
Width() (wid int, ok bool) Width() (wid int, ok bool)
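A hedged sketch of a custom Scanner written against the new Token signature and the pre-Go1 conventions used in this patch (int runes, os.Error); the Word type and the sample input are invented for illustration.

package main

import (
	"fmt"
	"os"
	"unicode"
)

// Word scans a maximal run of letters, skipping leading space.
type Word string

func (w *Word) Scan(state fmt.ScanState, verb int) os.Error {
	// skipSpace=true skips leading space; unicode.IsLetter selects
	// the code points that belong to the token.
	tok, err := state.Token(true, unicode.IsLetter)
	if err != nil {
		return err
	}
	*w = Word(tok) // copy: the returned bytes are shared with the ScanState
	return nil
}

func main() {
	var w Word
	fmt.Sscan("  hello, world", &w)
	fmt.Println(w) // prints "hello"; the comma ends the letter run
}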
...@@ -134,7 +139,7 @@ type scanError struct { ...@@ -134,7 +139,7 @@ type scanError struct {
err os.Error err os.Error
} }
const EOF = -1 const eof = -1
// ss is the internal implementation of ScanState. // ss is the internal implementation of ScanState.
type ss struct { type ss struct {
...@@ -202,7 +207,7 @@ func (s *ss) getRune() (rune int) { ...@@ -202,7 +207,7 @@ func (s *ss) getRune() (rune int) {
rune, _, err := s.ReadRune() rune, _, err := s.ReadRune()
if err != nil { if err != nil {
if err == os.EOF { if err == os.EOF {
return EOF return eof
} }
s.error(err) s.error(err)
} }
...@@ -214,7 +219,7 @@ func (s *ss) getRune() (rune int) { ...@@ -214,7 +219,7 @@ func (s *ss) getRune() (rune int) {
// syntax error. // syntax error.
func (s *ss) mustReadRune() (rune int) { func (s *ss) mustReadRune() (rune int) {
rune = s.getRune() rune = s.getRune()
if rune == EOF { if rune == eof {
s.error(io.ErrUnexpectedEOF) s.error(io.ErrUnexpectedEOF)
} }
return return
...@@ -238,7 +243,7 @@ func (s *ss) errorString(err string) { ...@@ -238,7 +243,7 @@ func (s *ss) errorString(err string) {
panic(scanError{os.ErrorString(err)}) panic(scanError{os.ErrorString(err)})
} }
func (s *ss) Token() (tok string, err os.Error) { func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) {
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
if se, ok := e.(scanError); ok { if se, ok := e.(scanError); ok {
...@@ -248,10 +253,19 @@ func (s *ss) Token() (tok string, err os.Error) { ...@@ -248,10 +253,19 @@ func (s *ss) Token() (tok string, err os.Error) {
} }
} }
}() }()
tok = s.token() if f == nil {
f = notSpace
}
s.buf.Reset()
tok = s.token(skipSpace, f)
return return
} }
// notSpace is the default scanning function used in Token.
func notSpace(r int) bool {
return !unicode.IsSpace(r)
}
// readRune is a structure to enable reading UTF-8 encoded code points // readRune is a structure to enable reading UTF-8 encoded code points
// from an io.Reader. It is used if the Reader given to the scanner does // from an io.Reader. It is used if the Reader given to the scanner does
// not already implement io.RuneReader. // not already implement io.RuneReader.
...@@ -364,7 +378,7 @@ func (s *ss) free(old ssave) { ...@@ -364,7 +378,7 @@ func (s *ss) free(old ssave) {
func (s *ss) skipSpace(stopAtNewline bool) { func (s *ss) skipSpace(stopAtNewline bool) {
for { for {
rune := s.getRune() rune := s.getRune()
if rune == EOF { if rune == eof {
return return
} }
if rune == '\n' { if rune == '\n' {
...@@ -384,24 +398,27 @@ func (s *ss) skipSpace(stopAtNewline bool) { ...@@ -384,24 +398,27 @@ func (s *ss) skipSpace(stopAtNewline bool) {
} }
} }
// token returns the next space-delimited string from the input. It // token returns the next space-delimited string from the input. It
// skips white space. For Scanln, it stops at newlines. For Scan, // skips white space. For Scanln, it stops at newlines. For Scan,
// newlines are treated as spaces. // newlines are treated as spaces.
func (s *ss) token() string { func (s *ss) token(skipSpace bool, f func(int) bool) []byte {
s.skipSpace(false) if skipSpace {
s.skipSpace(false)
}
// read until white space or newline // read until white space or newline
for { for {
rune := s.getRune() rune := s.getRune()
if rune == EOF { if rune == eof {
break break
} }
if unicode.IsSpace(rune) { if !f(rune) {
s.UnreadRune() s.UnreadRune()
break break
} }
s.buf.WriteRune(rune) s.buf.WriteRune(rune)
} }
return s.buf.String() return s.buf.Bytes()
} }
// typeError indicates that the type of the operand did not match the format // typeError indicates that the type of the operand did not match the format
...@@ -416,7 +433,7 @@ var boolError = os.ErrorString("syntax error scanning boolean") ...@@ -416,7 +433,7 @@ var boolError = os.ErrorString("syntax error scanning boolean")
// If accept is true, it puts the character into the input token. // If accept is true, it puts the character into the input token.
func (s *ss) consume(ok string, accept bool) bool { func (s *ss) consume(ok string, accept bool) bool {
rune := s.getRune() rune := s.getRune()
if rune == EOF { if rune == eof {
return false return false
} }
if strings.IndexRune(ok, rune) >= 0 { if strings.IndexRune(ok, rune) >= 0 {
...@@ -425,7 +442,7 @@ func (s *ss) consume(ok string, accept bool) bool { ...@@ -425,7 +442,7 @@ func (s *ss) consume(ok string, accept bool) bool {
} }
return true return true
} }
if rune != EOF && accept { if rune != eof && accept {
s.UnreadRune() s.UnreadRune()
} }
return false return false
...@@ -434,7 +451,7 @@ func (s *ss) consume(ok string, accept bool) bool { ...@@ -434,7 +451,7 @@ func (s *ss) consume(ok string, accept bool) bool {
// peek reports whether the next character is in the ok string, without consuming it. // peek reports whether the next character is in the ok string, without consuming it.
func (s *ss) peek(ok string) bool { func (s *ss) peek(ok string) bool {
rune := s.getRune() rune := s.getRune()
if rune != EOF { if rune != eof {
s.UnreadRune() s.UnreadRune()
} }
return strings.IndexRune(ok, rune) >= 0 return strings.IndexRune(ok, rune) >= 0
...@@ -729,7 +746,7 @@ func (s *ss) convertString(verb int) (str string) { ...@@ -729,7 +746,7 @@ func (s *ss) convertString(verb int) (str string) {
case 'x': case 'x':
str = s.hexString() str = s.hexString()
default: default:
str = s.token() // %s and %v just return the next word str = string(s.token(true, notSpace)) // %s and %v just return the next word
} }
// Empty strings other than with %q are not OK. // Empty strings other than with %q are not OK.
if len(str) == 0 && verb != 'q' && s.maxWid > 0 { if len(str) == 0 && verb != 'q' && s.maxWid > 0 {
...@@ -797,7 +814,7 @@ func (s *ss) hexDigit(digit int) int { ...@@ -797,7 +814,7 @@ func (s *ss) hexDigit(digit int) int {
// There must be either two hexadecimal digits or a space character in the input. // There must be either two hexadecimal digits or a space character in the input.
func (s *ss) hexByte() (b byte, ok bool) { func (s *ss) hexByte() (b byte, ok bool) {
rune1 := s.getRune() rune1 := s.getRune()
if rune1 == EOF { if rune1 == eof {
return return
} }
if unicode.IsSpace(rune1) { if unicode.IsSpace(rune1) {
...@@ -953,7 +970,7 @@ func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) { ...@@ -953,7 +970,7 @@ func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) {
if !s.nlIsSpace { if !s.nlIsSpace {
for { for {
rune := s.getRune() rune := s.getRune()
if rune == '\n' || rune == EOF { if rune == '\n' || rune == eof {
break break
} }
if !unicode.IsSpace(rune) { if !unicode.IsSpace(rune) {
...@@ -993,7 +1010,7 @@ func (s *ss) advance(format string) (i int) { ...@@ -993,7 +1010,7 @@ func (s *ss) advance(format string) (i int) {
// There was space in the format, so there should be space (EOF) // There was space in the format, so there should be space (EOF)
// in the input. // in the input.
inputc := s.getRune() inputc := s.getRune()
if inputc == EOF { if inputc == eof {
return return
} }
if !unicode.IsSpace(inputc) { if !unicode.IsSpace(inputc) {
......
...@@ -88,14 +88,15 @@ type FloatTest struct { ...@@ -88,14 +88,15 @@ type FloatTest struct {
type Xs string type Xs string
func (x *Xs) Scan(state ScanState, verb int) os.Error { func (x *Xs) Scan(state ScanState, verb int) os.Error {
tok, err := state.Token() tok, err := state.Token(true, func(r int) bool { return r == verb })
if err != nil { if err != nil {
return err return err
} }
if !regexp.MustCompile("^" + string(verb) + "+$").MatchString(tok) { s := string(tok)
if !regexp.MustCompile("^" + string(verb) + "+$").MatchString(s) {
return os.ErrorString("syntax error for xs") return os.ErrorString("syntax error for xs")
} }
*x = Xs(tok) *x = Xs(s)
return nil return nil
} }
...@@ -113,9 +114,11 @@ func (s *IntString) Scan(state ScanState, verb int) os.Error { ...@@ -113,9 +114,11 @@ func (s *IntString) Scan(state ScanState, verb int) os.Error {
return err return err
} }
if _, err := Fscan(state, &s.s); err != nil { tok, err := state.Token(true, nil)
if err != nil {
return err return err
} }
s.s = string(tok)
return nil return nil
} }
...@@ -331,7 +334,7 @@ var multiTests = []ScanfMultiTest{ ...@@ -331,7 +334,7 @@ var multiTests = []ScanfMultiTest{
{"%c%c%c", "2\u50c2X", args(&i, &j, &k), args('2', '\u50c2', 'X'), ""}, {"%c%c%c", "2\u50c2X", args(&i, &j, &k), args('2', '\u50c2', 'X'), ""},
// Custom scanners. // Custom scanners.
{"%2e%f", "eefffff", args(&x, &y), args(Xs("ee"), Xs("fffff")), ""}, {"%e%f", "eefffff", args(&x, &y), args(Xs("ee"), Xs("fffff")), ""},
{"%4v%s", "12abcd", args(&z, &s), args(IntString{12, "ab"}, "cd"), ""}, {"%4v%s", "12abcd", args(&z, &s), args(IntString{12, "ab"}, "cd"), ""},
// Errors // Errors
...@@ -476,22 +479,12 @@ func verifyInf(str string, t *testing.T) { ...@@ -476,22 +479,12 @@ func verifyInf(str string, t *testing.T) {
} }
} }
func TestInf(t *testing.T) { func TestInf(t *testing.T) {
for _, s := range []string{"inf", "+inf", "-inf", "INF", "-INF", "+INF", "Inf", "-Inf", "+Inf"} { for _, s := range []string{"inf", "+inf", "-inf", "INF", "-INF", "+INF", "Inf", "-Inf", "+Inf"} {
verifyInf(s, t) verifyInf(s, t)
} }
} }
// TODO: there's no conversion from []T to ...T, but we can fake it. These
// functions do the faking. We index the table by the length of the param list.
var fscanf = []func(io.Reader, string, []interface{}) (int, os.Error){
0: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f) },
1: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0]) },
2: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0], i[1]) },
3: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0], i[1], i[2]) },
}
func testScanfMulti(name string, t *testing.T) { func testScanfMulti(name string, t *testing.T) {
sliceType := reflect.Typeof(make([]interface{}, 1)).(*reflect.SliceType) sliceType := reflect.Typeof(make([]interface{}, 1)).(*reflect.SliceType)
for _, test := range multiTests { for _, test := range multiTests {
...@@ -501,7 +494,7 @@ func testScanfMulti(name string, t *testing.T) { ...@@ -501,7 +494,7 @@ func testScanfMulti(name string, t *testing.T) {
} else { } else {
r = newReader(test.text) r = newReader(test.text)
} }
n, err := fscanf[len(test.in)](r, test.format, test.in) n, err := Fscanf(r, test.format, test.in...)
if err != nil { if err != nil {
if test.err == "" { if test.err == "" {
t.Errorf("got error scanning (%q, %q): %q", test.format, test.text, err) t.Errorf("got error scanning (%q, %q): %q", test.format, test.text, err)
...@@ -830,12 +823,12 @@ func testScanInts(t *testing.T, scan func(*RecursiveInt, *bytes.Buffer) os.Error ...@@ -830,12 +823,12 @@ func testScanInts(t *testing.T, scan func(*RecursiveInt, *bytes.Buffer) os.Error
i := 1 i := 1
for ; r != nil; r = r.next { for ; r != nil; r = r.next {
if r.i != i { if r.i != i {
t.Fatal("bad scan: expected %d got %d", i, r.i) t.Fatalf("bad scan: expected %d got %d", i, r.i)
} }
i++ i++
} }
if i-1 != intCount { if i-1 != intCount {
t.Fatal("bad scan count: expected %d got %d", intCount, i-1) t.Fatalf("bad scan count: expected %d got %d", intCount, i-1)
} }
} }
......
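The fixed-arity fscanf wrappers deleted above became unnecessary because a []interface{} can now be passed directly to a ...interface{} parameter; a minimal sketch (input and variable names invented):

package main

import (
	"fmt"
	"strings"
)

func main() {
	var a, b int
	args := []interface{}{&a, &b}
	// The trailing ... spreads the slice into Fscanf's variadic
	// parameter, so no per-length wrapper table is needed.
	n, err := fmt.Fscanf(strings.NewReader("3 4"), "%d %d", args...)
	fmt.Println(n, err, a, b) // 2 <nil> 3 4
}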
...@@ -602,12 +602,12 @@ type ( ...@@ -602,12 +602,12 @@ type (
Else Stmt // else branch; or nil Else Stmt // else branch; or nil
} }
// A CaseClause represents a case of an expression switch statement. // A CaseClause represents a case of an expression or type switch statement.
CaseClause struct { CaseClause struct {
Case token.Pos // position of "case" or "default" keyword Case token.Pos // position of "case" or "default" keyword
Values []Expr // nil means default case List []Expr // list of expressions or types; nil means default case
Colon token.Pos // position of ":" Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil Body []Stmt // statement list; or nil
} }
// A SwitchStmt node represents an expression switch statement. // A SwitchStmt node represents an expression switch statement.
...@@ -618,20 +618,12 @@ type ( ...@@ -618,20 +618,12 @@ type (
Body *BlockStmt // CaseClauses only Body *BlockStmt // CaseClauses only
} }
// A TypeCaseClause represents a case of a type switch statement.
TypeCaseClause struct {
Case token.Pos // position of "case" or "default" keyword
Types []Expr // nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// A TypeSwitchStmt node represents a type switch statement. // A TypeSwitchStmt node represents a type switch statement.
TypeSwitchStmt struct { TypeSwitchStmt struct {
Switch token.Pos // position of "switch" keyword Switch token.Pos // position of "switch" keyword
Init Stmt // initialization statement; or nil Init Stmt // initialization statement; or nil
Assign Stmt // x := y.(type) Assign Stmt // x := y.(type) or y.(type)
Body *BlockStmt // TypeCaseClauses only Body *BlockStmt // CaseClauses only
} }
// A CommClause node represents a case of a select statement. // A CommClause node represents a case of a select statement.
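With TypeCaseClause folded into CaseClause, one visitor now handles the cases of both switch forms; a hedged sketch against the pre-Go1 go/ast API (the caseCounter type and sample source are invented):

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

// caseCounter tallies case clauses; it covers expression and type
// switches alike, since both now use *ast.CaseClause.
type caseCounter struct {
	cases, defaults int
}

func (c *caseCounter) Visit(n ast.Node) ast.Visitor {
	if cc, ok := n.(*ast.CaseClause); ok {
		if cc.List == nil { // nil List marks the default case
			c.defaults++
		} else {
			c.cases += len(cc.List) // expressions or types
		}
	}
	return c
}

func main() {
	src := "package p\nfunc f(x interface{}) {\n\tswitch x.(type) {\n\tcase int, string:\n\tdefault:\n\t}\n}\n"
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "p.go", src, 0)
	if err != nil {
		panic(err)
	}
	var c caseCounter
	ast.Walk(&c, f)
	fmt.Println(c.cases, c.defaults) // 2 1
}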
...@@ -687,7 +679,6 @@ func (s *BlockStmt) Pos() token.Pos { return s.Lbrace } ...@@ -687,7 +679,6 @@ func (s *BlockStmt) Pos() token.Pos { return s.Lbrace }
func (s *IfStmt) Pos() token.Pos { return s.If } func (s *IfStmt) Pos() token.Pos { return s.If }
func (s *CaseClause) Pos() token.Pos { return s.Case } func (s *CaseClause) Pos() token.Pos { return s.Case }
func (s *SwitchStmt) Pos() token.Pos { return s.Switch } func (s *SwitchStmt) Pos() token.Pos { return s.Switch }
func (s *TypeCaseClause) Pos() token.Pos { return s.Case }
func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch } func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch }
func (s *CommClause) Pos() token.Pos { return s.Case } func (s *CommClause) Pos() token.Pos { return s.Case }
func (s *SelectStmt) Pos() token.Pos { return s.Select } func (s *SelectStmt) Pos() token.Pos { return s.Select }
...@@ -734,13 +725,7 @@ func (s *CaseClause) End() token.Pos { ...@@ -734,13 +725,7 @@ func (s *CaseClause) End() token.Pos {
} }
return s.Colon + 1 return s.Colon + 1
} }
func (s *SwitchStmt) End() token.Pos { return s.Body.End() } func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
func (s *TypeCaseClause) End() token.Pos {
if n := len(s.Body); n > 0 {
return s.Body[n-1].End()
}
return s.Colon + 1
}
func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() } func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() }
func (s *CommClause) End() token.Pos { func (s *CommClause) End() token.Pos {
if n := len(s.Body); n > 0 { if n := len(s.Body); n > 0 {
...@@ -772,7 +757,6 @@ func (s *BlockStmt) stmtNode() {} ...@@ -772,7 +757,6 @@ func (s *BlockStmt) stmtNode() {}
func (s *IfStmt) stmtNode() {} func (s *IfStmt) stmtNode() {}
func (s *CaseClause) stmtNode() {} func (s *CaseClause) stmtNode() {}
func (s *SwitchStmt) stmtNode() {} func (s *SwitchStmt) stmtNode() {}
func (s *TypeCaseClause) stmtNode() {}
func (s *TypeSwitchStmt) stmtNode() {} func (s *TypeSwitchStmt) stmtNode() {}
func (s *CommClause) stmtNode() {} func (s *CommClause) stmtNode() {}
func (s *SelectStmt) stmtNode() {} func (s *SelectStmt) stmtNode() {}
...@@ -937,11 +921,13 @@ func (d *FuncDecl) declNode() {} ...@@ -937,11 +921,13 @@ func (d *FuncDecl) declNode() {}
// via Doc and Comment fields. // via Doc and Comment fields.
// //
type File struct { type File struct {
Doc *CommentGroup // associated documentation; or nil Doc *CommentGroup // associated documentation; or nil
Package token.Pos // position of "package" keyword Package token.Pos // position of "package" keyword
Name *Ident // package name Name *Ident // package name
Decls []Decl // top-level declarations; or nil Decls []Decl // top-level declarations; or nil
Comments []*CommentGroup // list of all comments in the source file Scope *Scope // package scope
Unresolved []*Ident // unresolved global identifiers
Comments []*CommentGroup // list of all comments in the source file
} }
...@@ -959,7 +945,7 @@ func (f *File) End() token.Pos { ...@@ -959,7 +945,7 @@ func (f *File) End() token.Pos {
// //
type Package struct { type Package struct {
Name string // package name Name string // package name
Scope *Scope // package scope; or nil Scope *Scope // package scope
Files map[string]*File // Go source files by filename Files map[string]*File // Go source files by filename
} }
......
...@@ -425,5 +425,6 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File { ...@@ -425,5 +425,6 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
} }
} }
return &File{doc, pos, NewIdent(pkg.Name), decls, comments} // TODO(gri) need to compute pkgScope and unresolved identifiers!
return &File{doc, pos, NewIdent(pkg.Name), decls, nil, nil, comments}
} }
...@@ -30,15 +30,19 @@ func NotNilFilter(_ string, value reflect.Value) bool { ...@@ -30,15 +30,19 @@ func NotNilFilter(_ string, value reflect.Value) bool {
// Fprint prints the (sub-)tree starting at AST node x to w. // Fprint prints the (sub-)tree starting at AST node x to w.
// If fset != nil, position information is interpreted relative
// to that file set. Otherwise positions are printed as integer
// values (file set specific offsets).
// //
// A non-nil FieldFilter f may be provided to control the output: // A non-nil FieldFilter f may be provided to control the output:
// struct fields for which f(fieldname, fieldvalue) is true are // struct fields for which f(fieldname, fieldvalue) is true are
// printed; all others are filtered from the output. // printed; all others are filtered from the output.
// //
func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) { func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n int, err os.Error) {
// setup printer // setup printer
p := printer{ p := printer{
output: w, output: w,
fset: fset,
filter: f, filter: f,
ptrmap: make(map[interface{}]int), ptrmap: make(map[interface{}]int),
last: '\n', // force printing of line number on first line last: '\n', // force printing of line number on first line
...@@ -65,14 +69,15 @@ func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) { ...@@ -65,14 +69,15 @@ func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) {
// Print prints x to standard output, skipping nil fields. // Print prints x to standard output, skipping nil fields.
// Print(x) is the same as Fprint(os.Stdout, x, NotNilFilter). // Print(fset, x) is the same as Fprint(os.Stdout, fset, x, NotNilFilter).
func Print(x interface{}) (int, os.Error) { func Print(fset *token.FileSet, x interface{}) (int, os.Error) {
return Fprint(os.Stdout, x, NotNilFilter) return Fprint(os.Stdout, fset, x, NotNilFilter)
} }
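A hedged usage sketch of the new fset parameter (file name and source invented): supplying the FileSet makes the dump print file:line:column positions instead of raw token.Pos offsets.

package main

import (
	"go/ast"
	"go/parser"
	"go/token"
)

func main() {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "answer.go", "package p\nvar answer = 42\n", 0)
	if err != nil {
		panic(err)
	}
	ast.Print(fset, f) // positions rendered via fset; passing nil yields raw offsets
}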
type printer struct { type printer struct {
output io.Writer output io.Writer
fset *token.FileSet
filter FieldFilter filter FieldFilter
ptrmap map[interface{}]int // *reflect.PtrValue -> line number ptrmap map[interface{}]int // *reflect.PtrValue -> line number
written int // number of bytes written to output written int // number of bytes written to output
...@@ -137,16 +142,6 @@ func (p *printer) printf(format string, args ...interface{}) { ...@@ -137,16 +142,6 @@ func (p *printer) printf(format string, args ...interface{}) {
// probably be in a different package. // probably be in a different package.
func (p *printer) print(x reflect.Value) { func (p *printer) print(x reflect.Value) {
// Note: This test is only needed because AST nodes
// embed a token.Position, and thus all of them
// understand the String() method (but it only
// applies to the Position field).
// TODO: Should reconsider this AST design decision.
if pos, ok := x.Interface().(token.Position); ok {
p.printf("%s", pos)
return
}
if !NotNilFilter("", x) { if !NotNilFilter("", x) {
p.printf("nil") p.printf("nil")
return return
...@@ -163,6 +158,7 @@ func (p *printer) print(x reflect.Value) { ...@@ -163,6 +158,7 @@ func (p *printer) print(x reflect.Value) {
p.print(key) p.print(key)
p.printf(": ") p.printf(": ")
p.print(v.Elem(key)) p.print(v.Elem(key))
p.printf("\n")
} }
p.indent-- p.indent--
p.printf("}") p.printf("}")
...@@ -212,6 +208,11 @@ func (p *printer) print(x reflect.Value) { ...@@ -212,6 +208,11 @@ func (p *printer) print(x reflect.Value) {
p.printf("}") p.printf("}")
default: default:
p.printf("%v", x.Interface()) value := x.Interface()
// position values can be printed nicely if we have a file set
if pos, ok := value.(token.Pos); ok && p.fset != nil {
value = p.fset.Position(pos)
}
p.printf("%v", value)
} }
} }
...@@ -2,31 +2,31 @@ ...@@ -2,31 +2,31 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This file implements scopes, the objects they contain, // This file implements scopes and the objects they contain.
// and object types.
package ast package ast
import (
"bytes"
"fmt"
"go/token"
)
// A Scope maintains the set of named language entities declared // A Scope maintains the set of named language entities declared
// in the scope and a link to the immediately surrounding (outer) // in the scope and a link to the immediately surrounding (outer)
// scope. // scope.
// //
type Scope struct { type Scope struct {
Outer *Scope Outer *Scope
Objects []*Object // in declaration order Objects map[string]*Object
// Implementation note: In some cases (struct fields,
// function parameters) we need the source order of
// variables. Thus for now, we store scope entries
// in a linear list. If scopes become very large
// (say, for packages), we may need to change this
// to avoid slow lookups.
} }
// NewScope creates a new scope nested in the outer scope. // NewScope creates a new scope nested in the outer scope.
func NewScope(outer *Scope) *Scope { func NewScope(outer *Scope) *Scope {
const n = 4 // initial scope capacity, must be > 0 const n = 4 // initial scope capacity
return &Scope{outer, make([]*Object, 0, n)} return &Scope{outer, make(map[string]*Object, n)}
} }
...@@ -34,73 +34,108 @@ func NewScope(outer *Scope) *Scope { ...@@ -34,73 +34,108 @@ func NewScope(outer *Scope) *Scope {
// found in scope s, otherwise it returns nil. Outer scopes // found in scope s, otherwise it returns nil. Outer scopes
// are ignored. // are ignored.
// //
// Lookup always returns nil if name is "_", even if the scope
// contains objects with that name.
//
func (s *Scope) Lookup(name string) *Object { func (s *Scope) Lookup(name string) *Object {
if name != "_" { return s.Objects[name]
for _, obj := range s.Objects {
if obj.Name == name {
return obj
}
}
}
return nil
} }
// Insert attempts to insert a named object into the scope s. // Insert attempts to insert a named object into the scope s.
// If the scope does not contain an object with that name yet // If the scope does not contain an object with that name yet,
// or if the object is named "_", Insert inserts the object // Insert inserts the object and returns it. Otherwise, Insert
// and returns it. Otherwise, Insert leaves the scope unchanged // leaves the scope unchanged and returns the object found in
// and returns the object found in the scope instead. // the scope instead.
// //
func (s *Scope) Insert(obj *Object) *Object { func (s *Scope) Insert(obj *Object) (alt *Object) {
alt := s.Lookup(obj.Name) if alt = s.Objects[obj.Name]; alt == nil {
if alt == nil { s.Objects[obj.Name] = obj
s.append(obj)
alt = obj alt = obj
} }
return alt return
} }
func (s *Scope) append(obj *Object) { // Debugging support
s.Objects = append(s.Objects, obj) func (s *Scope) String() string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "scope %p {", s)
if s != nil && len(s.Objects) > 0 {
fmt.Fprintln(&buf)
for _, obj := range s.Objects {
fmt.Fprintf(&buf, "\t%s %s\n", obj.Kind, obj.Name)
}
}
fmt.Fprintf(&buf, "}\n")
return buf.String()
} }
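A hedged sketch of the map-backed scope above (object names invented): Insert returns the earlier object on a name clash and the new object otherwise, and Lookup is now a plain map access.

package main

import (
	"fmt"
	"go/ast"
)

func main() {
	pkg := ast.NewScope(nil) // no outer scope
	x := ast.NewObj(ast.Var, "x")
	if alt := pkg.Insert(x); alt == x {
		fmt.Println("declared", x.Kind, x.Name) // declared var x
	}
	dup := ast.NewObj(ast.Con, "x")
	if alt := pkg.Insert(dup); alt != dup {
		fmt.Println("clash with earlier", alt.Kind, alt.Name) // clash with earlier var x
	}
	fmt.Println(pkg.Lookup("x").Kind) // var
}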
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Objects // Objects
// An Object describes a language entity such as a package, // An Object describes a named language entity such as a package,
// constant, type, variable, or function (incl. methods). // constant, type, variable, function (incl. methods), or label.
// //
type Object struct { type Object struct {
Kind Kind Kind ObjKind
Name string // declared name Name string // declared name
Type *Type Decl interface{} // corresponding Field, XxxSpec, FuncDecl, or LabeledStmt; or nil
Decl interface{} // corresponding Field, XxxSpec or FuncDecl Type interface{} // place holder for type information; may be nil
N int // value of iota for this declaration
} }
// NewObj creates a new object of a given kind and name. // NewObj creates a new object of a given kind and name.
func NewObj(kind Kind, name string) *Object { func NewObj(kind ObjKind, name string) *Object {
return &Object{Kind: kind, Name: name} return &Object{Kind: kind, Name: name}
} }
// Kind describes what an object represents. // Pos computes the source position of the declaration of an object name.
type Kind int // The result may be an invalid position if it cannot be computed
// (obj.Decl may be nil or not correct).
func (obj *Object) Pos() token.Pos {
name := obj.Name
switch d := obj.Decl.(type) {
case *Field:
for _, n := range d.Names {
if n.Name == name {
return n.Pos()
}
}
case *ValueSpec:
for _, n := range d.Names {
if n.Name == name {
return n.Pos()
}
}
case *TypeSpec:
if d.Name.Name == name {
return d.Name.Pos()
}
case *FuncDecl:
if d.Name.Name == name {
return d.Name.Pos()
}
case *LabeledStmt:
if d.Label.Name == name {
return d.Label.Pos()
}
}
return token.NoPos
}
// ObjKind describes what an object represents.
type ObjKind int
// The list of possible Object kinds. // The list of possible Object kinds.
const ( const (
Bad Kind = iota // for error handling Bad ObjKind = iota // for error handling
Pkg // package Pkg // package
Con // constant Con // constant
Typ // type Typ // type
Var // variable Var // variable
Fun // function or method Fun // function or method
Lbl // label
) )
...@@ -111,132 +146,8 @@ var objKindStrings = [...]string{ ...@@ -111,132 +146,8 @@ var objKindStrings = [...]string{
Typ: "type", Typ: "type",
Var: "var", Var: "var",
Fun: "func", Fun: "func",
Lbl: "label",
} }
func (kind Kind) String() string { return objKindStrings[kind] } func (kind ObjKind) String() string { return objKindStrings[kind] }
// IsExported returns whether obj is exported.
func (obj *Object) IsExported() bool { return IsExported(obj.Name) }
// ----------------------------------------------------------------------------
// Types
// A Type represents a Go type.
type Type struct {
Form Form
Obj *Object // corresponding type name, or nil
Scope *Scope // fields and methods, always present
N uint // basic type id, array length, number of function results, or channel direction
Key, Elt *Type // map key and array, pointer, slice, map or channel element
Params *Scope // function (receiver, input and result) parameters, tuple expressions (results of function calls), or nil
Expr Expr // corresponding AST expression
}
// NewType creates a new type of a given form.
func NewType(form Form) *Type {
return &Type{Form: form, Scope: NewScope(nil)}
}
// Form describes the form of a type.
type Form int
// The list of possible type forms.
const (
BadType Form = iota // for error handling
Unresolved // type not fully setup
Basic
Array
Struct
Pointer
Function
Method
Interface
Slice
Map
Channel
Tuple
)
var formStrings = [...]string{
BadType: "badType",
Unresolved: "unresolved",
Basic: "basic",
Array: "array",
Struct: "struct",
Pointer: "pointer",
Function: "function",
Method: "method",
Interface: "interface",
Slice: "slice",
Map: "map",
Channel: "channel",
Tuple: "tuple",
}
func (form Form) String() string { return formStrings[form] }
// The list of basic type id's.
const (
Bool = iota
Byte
Uint
Int
Float
Complex
Uintptr
String
Uint8
Uint16
Uint32
Uint64
Int8
Int16
Int32
Int64
Float32
Float64
Complex64
Complex128
// TODO(gri) ideal types are missing
)
var BasicTypes = map[uint]string{
Bool: "bool",
Byte: "byte",
Uint: "uint",
Int: "int",
Float: "float",
Complex: "complex",
Uintptr: "uintptr",
String: "string",
Uint8: "uint8",
Uint16: "uint16",
Uint32: "uint32",
Uint64: "uint64",
Int8: "int8",
Int16: "int16",
Int32: "int32",
Int64: "int64",
Float32: "float32",
Float64: "float64",
Complex64: "complex64",
Complex128: "complex128",
}
...@@ -234,7 +234,7 @@ func Walk(v Visitor, node Node) { ...@@ -234,7 +234,7 @@ func Walk(v Visitor, node Node) {
} }
case *CaseClause: case *CaseClause:
walkExprList(v, n.Values) walkExprList(v, n.List)
walkStmtList(v, n.Body) walkStmtList(v, n.Body)
case *SwitchStmt: case *SwitchStmt:
...@@ -246,12 +246,6 @@ func Walk(v Visitor, node Node) { ...@@ -246,12 +246,6 @@ func Walk(v Visitor, node Node) {
} }
Walk(v, n.Body) Walk(v, n.Body)
case *TypeCaseClause:
for _, x := range n.Types {
Walk(v, x)
}
walkStmtList(v, n.Body)
case *TypeSwitchStmt: case *TypeSwitchStmt:
if n.Init != nil { if n.Init != nil {
Walk(v, n.Init) Walk(v, n.Init)
......
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
pathutil "path" "path/filepath"
) )
...@@ -198,7 +198,7 @@ func ParseDir(fset *token.FileSet, path string, filter func(*os.FileInfo) bool, ...@@ -198,7 +198,7 @@ func ParseDir(fset *token.FileSet, path string, filter func(*os.FileInfo) bool,
for i := 0; i < len(list); i++ { for i := 0; i < len(list); i++ {
d := &list[i] d := &list[i]
if filter == nil || filter(d) { if filter == nil || filter(d) {
filenames[n] = pathutil.Join(path, d.Name) filenames[n] = filepath.Join(path, d.Name)
n++ n++
} }
} }
......
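The switch from path to path/filepath here and in the scanner and printer files below is what makes file names OS-correct; a minimal sketch of the difference (paths invented):

package main

import (
	"fmt"
	"path"
	"path/filepath"
)

func main() {
	// path always joins with forward slashes; filepath uses the
	// host separator, which matters for real file names.
	fmt.Println(path.Join("dir", "file.go"))     // dir/file.go on every OS
	fmt.Println(filepath.Join("dir", "file.go")) // dir\file.go on Windows
}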
...@@ -21,6 +21,7 @@ var illegalInputs = []interface{}{ ...@@ -21,6 +21,7 @@ var illegalInputs = []interface{}{
`package p; func f() { if /* should have condition */ {} };`, `package p; func f() { if /* should have condition */ {} };`,
`package p; func f() { if ; /* should have condition */ {} };`, `package p; func f() { if ; /* should have condition */ {} };`,
`package p; func f() { if f(); /* should have condition */ {} };`, `package p; func f() { if f(); /* should have condition */ {} };`,
`package p; const c; /* should have constant value */`,
} }
...@@ -73,7 +74,7 @@ var validFiles = []string{ ...@@ -73,7 +74,7 @@ var validFiles = []string{
func TestParse3(t *testing.T) { func TestParse3(t *testing.T) {
for _, filename := range validFiles { for _, filename := range validFiles {
_, err := ParseFile(fset, filename, nil, 0) _, err := ParseFile(fset, filename, nil, DeclarationErrors)
if err != nil { if err != nil {
t.Errorf("ParseFile(%s): %v", filename, err) t.Errorf("ParseFile(%s): %v", filename, err)
} }
......
...@@ -12,7 +12,7 @@ import ( ...@@ -12,7 +12,7 @@ import (
"go/token" "go/token"
"io" "io"
"os" "os"
"path" "path/filepath"
"runtime" "runtime"
"tabwriter" "tabwriter"
) )
...@@ -94,22 +94,23 @@ type printer struct { ...@@ -94,22 +94,23 @@ type printer struct {
// written using writeItem. // written using writeItem.
last token.Position last token.Position
// HTML support
lastTaggedLine int // last line for which a line tag was written
// The list of all source comments, in order of appearance. // The list of all source comments, in order of appearance.
comments []*ast.CommentGroup // may be nil comments []*ast.CommentGroup // may be nil
cindex int // current comment index cindex int // current comment index
useNodeComments bool // if not set, ignore lead and line comments of nodes useNodeComments bool // if not set, ignore lead and line comments of nodes
// Cache of already computed node sizes.
nodeSizes map[ast.Node]int
} }
func (p *printer) init(output io.Writer, cfg *Config, fset *token.FileSet) { func (p *printer) init(output io.Writer, cfg *Config, fset *token.FileSet, nodeSizes map[ast.Node]int) {
p.output = output p.output = output
p.Config = *cfg p.Config = *cfg
p.fset = fset p.fset = fset
p.errors = make(chan os.Error) p.errors = make(chan os.Error)
p.buffer = make([]whiteSpace, 0, 16) // whitespace sequences are short p.buffer = make([]whiteSpace, 0, 16) // whitespace sequences are short
p.nodeSizes = nodeSizes
} }
...@@ -244,7 +245,7 @@ func (p *printer) writeItem(pos token.Position, data []byte) { ...@@ -244,7 +245,7 @@ func (p *printer) writeItem(pos token.Position, data []byte) {
} }
if debug { if debug {
// do not update p.pos - use write0 // do not update p.pos - use write0
_, filename := path.Split(pos.Filename) _, filename := filepath.Split(pos.Filename)
p.write0([]byte(fmt.Sprintf("[%s:%d:%d]", filename, pos.Line, pos.Column))) p.write0([]byte(fmt.Sprintf("[%s:%d:%d]", filename, pos.Line, pos.Column)))
} }
p.write(data) p.write(data)
...@@ -994,13 +995,8 @@ type Config struct { ...@@ -994,13 +995,8 @@ type Config struct {
} }
// Fprint "pretty-prints" an AST node to output and returns the number // fprint implements Fprint and takes a nodesSizes map for setting up the printer state.
// of bytes written and an error (if any) for a given configuration cfg. func (cfg *Config) fprint(output io.Writer, fset *token.FileSet, node interface{}, nodeSizes map[ast.Node]int) (int, os.Error) {
// Position information is interpreted relative to the file set fset.
// The node type must be *ast.File, or assignment-compatible to ast.Expr,
// ast.Decl, ast.Spec, or ast.Stmt.
//
func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) (int, os.Error) {
// redirect output through a trimmer to eliminate trailing whitespace // redirect output through a trimmer to eliminate trailing whitespace
// (Input to a tabwriter must be untrimmed since trailing tabs provide // (Input to a tabwriter must be untrimmed since trailing tabs provide
// formatting information. The tabwriter could provide trimming // formatting information. The tabwriter could provide trimming
...@@ -1029,7 +1025,7 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{ ...@@ -1029,7 +1025,7 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{
// setup printer and print node // setup printer and print node
var p printer var p printer
p.init(output, cfg, fset) p.init(output, cfg, fset, nodeSizes)
go func() { go func() {
switch n := node.(type) { switch n := node.(type) {
case ast.Expr: case ast.Expr:
...@@ -1076,6 +1072,17 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{ ...@@ -1076,6 +1072,17 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{
} }
// Fprint "pretty-prints" an AST node to output and returns the number
// of bytes written and an error (if any) for a given configuration cfg.
// Position information is interpreted relative to the file set fset.
// The node type must be *ast.File, or assignment-compatible to ast.Expr,
// ast.Decl, ast.Spec, or ast.Stmt.
//
func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) (int, os.Error) {
return cfg.fprint(output, fset, node, make(map[ast.Node]int))
}
// Fprint "pretty-prints" an AST node to output. // Fprint "pretty-prints" an AST node to output.
// It calls Config.Fprint with default settings. // It calls Config.Fprint with default settings.
// //
......
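A hedged usage sketch of Config.Fprint as declared above (source text invented; the Tabwidth value is only an example):

package main

import (
	"go/parser"
	"go/printer"
	"go/token"
	"os"
)

func main() {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "p.go", "package p\nvar   x=1\n", parser.ParseComments)
	if err != nil {
		panic(err)
	}
	cfg := printer.Config{Tabwidth: 8}
	cfg.Fprint(os.Stdout, fset, f) // pretty-prints the file with normalized spacing
}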
...@@ -11,8 +11,9 @@ import ( ...@@ -11,8 +11,9 @@ import (
"go/ast" "go/ast"
"go/parser" "go/parser"
"go/token" "go/token"
"path" "path/filepath"
"testing" "testing"
"time"
) )
...@@ -45,7 +46,7 @@ const ( ...@@ -45,7 +46,7 @@ const (
) )
func check(t *testing.T, source, golden string, mode checkMode) { func runcheck(t *testing.T, source, golden string, mode checkMode) {
// parse source // parse source
prog, err := parser.ParseFile(fset, source, nil, parser.ParseComments) prog, err := parser.ParseFile(fset, source, nil, parser.ParseComments)
if err != nil { if err != nil {
...@@ -109,6 +110,32 @@ func check(t *testing.T, source, golden string, mode checkMode) { ...@@ -109,6 +110,32 @@ func check(t *testing.T, source, golden string, mode checkMode) {
} }
func check(t *testing.T, source, golden string, mode checkMode) {
// start a timer to produce a time-out signal
tc := make(chan int)
go func() {
time.Sleep(20e9) // plenty of safety margin, even for very slow machines
tc <- 0
}()
// run the test
cc := make(chan int)
go func() {
runcheck(t, source, golden, mode)
cc <- 0
}()
// wait for the first finisher
select {
case <-tc:
// test running past time out
t.Errorf("%s: running too slowly", source)
case <-cc:
// test finished within allotted time margin
}
}
type entry struct { type entry struct {
source, golden string source, golden string
mode checkMode mode checkMode
...@@ -124,13 +151,14 @@ var data = []entry{ ...@@ -124,13 +151,14 @@ var data = []entry{
{"expressions.input", "expressions.raw", rawFormat}, {"expressions.input", "expressions.raw", rawFormat},
{"declarations.input", "declarations.golden", 0}, {"declarations.input", "declarations.golden", 0},
{"statements.input", "statements.golden", 0}, {"statements.input", "statements.golden", 0},
{"slow.input", "slow.golden", 0},
} }
func TestFiles(t *testing.T) { func TestFiles(t *testing.T) {
for _, e := range data { for _, e := range data {
source := path.Join(dataDir, e.source) source := filepath.Join(dataDir, e.source)
golden := path.Join(dataDir, e.golden) golden := filepath.Join(dataDir, e.golden)
check(t, source, golden, e.mode) check(t, source, golden, e.mode)
// TODO(gri) check that golden is idempotent // TODO(gri) check that golden is idempotent
//check(t, golden, golden, e.mode); //check(t, golden, golden, e.mode);
......
...@@ -224,11 +224,7 @@ func _() { ...@@ -224,11 +224,7 @@ func _() {
_ = struct{ x int }{0} _ = struct{ x int }{0}
_ = struct{ x, y, z int }{0, 1, 2} _ = struct{ x, y, z int }{0, 1, 2}
_ = struct{ int }{0} _ = struct{ int }{0}
_ = struct { _ = struct{ s struct{ int } }{struct{ int }{0}}
s struct {
int
}
}{struct{ int }{0}} // compositeLit context not propagated => multiLine result
} }
......
...@@ -224,7 +224,7 @@ func _() { ...@@ -224,7 +224,7 @@ func _() {
_ = struct{ x int }{0} _ = struct{ x int }{0}
_ = struct{ x, y, z int }{0, 1, 2} _ = struct{ x, y, z int }{0, 1, 2}
_ = struct{ int }{0} _ = struct{ int }{0}
_ = struct{ s struct { int } }{struct{ int}{0}} // compositeLit context not propagated => multiLine result _ = struct{ s struct { int } }{struct{ int}{0} }
} }
......
...@@ -224,11 +224,7 @@ func _() { ...@@ -224,11 +224,7 @@ func _() {
_ = struct{ x int }{0} _ = struct{ x int }{0}
_ = struct{ x, y, z int }{0, 1, 2} _ = struct{ x, y, z int }{0, 1, 2}
_ = struct{ int }{0} _ = struct{ int }{0}
_ = struct { _ = struct{ s struct{ int } }{struct{ int }{0}}
s struct {
int
}
}{struct{ int }{0}} // compositeLit context not propagated => multiLine result
} }
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deepequal_test
import (
"testing"
"google3/spam/archer/frontend/deepequal"
)
func TestTwoNilValues(t *testing.T) {
if err := deepequal.Check(nil, nil); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
type Foo struct {
bar *Bar
bang *Bar
}
type Bar struct {
baz *Baz
foo []*Foo
}
type Baz struct {
entries map[int]interface{}
whatever string
}
func newFoo() *Foo {
return &Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
42: &Foo{},
21: &Bar{},
11: &Baz{whatever: "it's just a test"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
43: &Foo{},
22: &Bar{},
13: &Baz{whatever: "this is nuts"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
61: &Foo{},
71: &Bar{},
11: &Baz{whatever: "no, it's Go"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
0: &Foo{},
-2: &Bar{},
-11: &Baz{whatever: "we need to go deeper"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
-2: &Foo{},
-5: &Bar{},
-7: &Baz{whatever: "are you serious?"}}}},
bang: &Bar{foo: []*Foo{}}},
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
-100: &Foo{},
50: &Bar{},
20: &Baz{whatever: "na, not really ..."}}}},
bang: &Bar{foo: []*Foo{}}}}}}}}},
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
2: &Foo{},
1: &Bar{},
-1: &Baz{whatever: "... it's just a test."}}}},
bang: &Bar{foo: []*Foo{}}}}}}}}}
}
func TestElaborate(t *testing.T) {
a := newFoo()
b := newFoo()
if err := deepequal.Check(a, b); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deepequal_test
import (
"testing"
"google3/spam/archer/frontend/deepequal"
)
func TestTwoNilValues(t *testing.T) {
if err := deepequal.Check(nil, nil); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
type Foo struct {
bar *Bar
bang *Bar
}
type Bar struct {
baz *Baz
foo []*Foo
}
type Baz struct {
entries map[int]interface{}
whatever string
}
func newFoo() (*Foo) {
return &Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
42: &Foo{},
21: &Bar{},
11: &Baz{ whatever: "it's just a test" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
43: &Foo{},
22: &Bar{},
13: &Baz{ whatever: "this is nuts" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
61: &Foo{},
71: &Bar{},
11: &Baz{ whatever: "no, it's Go" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
0: &Foo{},
-2: &Bar{},
-11: &Baz{ whatever: "we need to go deeper" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
-2: &Foo{},
-5: &Bar{},
-7: &Baz{ whatever: "are you serious?" }}}},
bang: &Bar{foo: []*Foo{}}},
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
-100: &Foo{},
50: &Bar{},
20: &Baz{ whatever: "na, not really ..." }}}},
bang: &Bar{foo: []*Foo{}}}}}}}}},
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
2: &Foo{},
1: &Bar{},
-1: &Baz{ whatever: "... it's just a test." }}}},
bang: &Bar{foo: []*Foo{}}}}}}}}}
}
func TestElaborate(t *testing.T) {
a := newFoo()
b := newFoo()
if err := deepequal.Check(a, b); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
...@@ -23,7 +23,7 @@ package scanner ...@@ -23,7 +23,7 @@ package scanner
import ( import (
"bytes" "bytes"
"go/token" "go/token"
"path" "path/filepath"
"strconv" "strconv"
"unicode" "unicode"
"utf8" "utf8"
...@@ -118,7 +118,7 @@ func (S *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode uint ...@@ -118,7 +118,7 @@ func (S *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode uint
panic("file size does not match src len") panic("file size does not match src len")
} }
S.file = file S.file = file
S.dir, _ = path.Split(file.Name()) S.dir, _ = filepath.Split(file.Name())
S.src = src S.src = src
S.err = err S.err = err
S.mode = mode S.mode = mode
...@@ -177,13 +177,13 @@ var prefix = []byte("//line ") ...@@ -177,13 +177,13 @@ var prefix = []byte("//line ")
func (S *Scanner) interpretLineComment(text []byte) { func (S *Scanner) interpretLineComment(text []byte) {
if bytes.HasPrefix(text, prefix) { if bytes.HasPrefix(text, prefix) {
// get filename and line number, if any // get filename and line number, if any
if i := bytes.Index(text, []byte{':'}); i > 0 { if i := bytes.LastIndex(text, []byte{':'}); i > 0 {
if line, err := strconv.Atoi(string(text[i+1:])); err == nil && line > 0 { if line, err := strconv.Atoi(string(text[i+1:])); err == nil && line > 0 {
// valid //line filename:line comment; // valid //line filename:line comment;
filename := path.Clean(string(text[len(prefix):i])) filename := filepath.Clean(string(text[len(prefix):i]))
if filename[0] != '/' { if !filepath.IsAbs(filename) {
// make filename relative to current directory // make filename relative to current directory
filename = path.Join(S.dir, filename) filename = filepath.Join(S.dir, filename)
} }
// update scanner position // update scanner position
S.file.AddLineInfo(S.lineOffset, filename, line-1) // -1 since comment applies to next line S.file.AddLineInfo(S.lineOffset, filename, line-1) // -1 since comment applies to next line
......
...@@ -7,6 +7,8 @@ package scanner ...@@ -7,6 +7,8 @@ package scanner
import ( import (
"go/token" "go/token"
"os" "os"
"path/filepath"
"runtime"
"testing" "testing"
) )
...@@ -443,32 +445,41 @@ func TestSemis(t *testing.T) { ...@@ -443,32 +445,41 @@ func TestSemis(t *testing.T) {
} }
} }
type segment struct {
var segments = []struct {
srcline string // a line of source text srcline string // a line of source text
filename string // filename for current token filename string // filename for current token
line int // line number for current token line int // line number for current token
}{ }
var segments = []segment{
// exactly one token per line since the test consumes one token per segment // exactly one token per line since the test consumes one token per segment
{" line1", "dir/TestLineComments", 1}, {" line1", filepath.Join("dir", "TestLineComments"), 1},
{"\nline2", "dir/TestLineComments", 2}, {"\nline2", filepath.Join("dir", "TestLineComments"), 2},
{"\nline3 //line File1.go:100", "dir/TestLineComments", 3}, // bad line comment, ignored {"\nline3 //line File1.go:100", filepath.Join("dir", "TestLineComments"), 3}, // bad line comment, ignored
{"\nline4", "dir/TestLineComments", 4}, {"\nline4", filepath.Join("dir", "TestLineComments"), 4},
{"\n//line File1.go:100\n line100", "dir/File1.go", 100}, {"\n//line File1.go:100\n line100", filepath.Join("dir", "File1.go"), 100},
{"\n//line File2.go:200\n line200", "dir/File2.go", 200}, {"\n//line File2.go:200\n line200", filepath.Join("dir", "File2.go"), 200},
{"\n//line :1\n line1", "dir", 1}, {"\n//line :1\n line1", "dir", 1},
{"\n//line foo:42\n line42", "dir/foo", 42}, {"\n//line foo:42\n line42", filepath.Join("dir", "foo"), 42},
{"\n //line foo:42\n line44", "dir/foo", 44}, // bad line comment, ignored {"\n //line foo:42\n line44", filepath.Join("dir", "foo"), 44}, // bad line comment, ignored
{"\n//line foo 42\n line46", "dir/foo", 46}, // bad line comment, ignored {"\n//line foo 42\n line46", filepath.Join("dir", "foo"), 46}, // bad line comment, ignored
{"\n//line foo:42 extra text\n line48", "dir/foo", 48}, // bad line comment, ignored {"\n//line foo:42 extra text\n line48", filepath.Join("dir", "foo"), 48}, // bad line comment, ignored
{"\n//line /bar:42\n line42", "/bar", 42}, {"\n//line /bar:42\n line42", string(filepath.Separator) + "bar", 42},
{"\n//line ./foo:42\n line42", "dir/foo", 42}, {"\n//line ./foo:42\n line42", filepath.Join("dir", "foo"), 42},
{"\n//line a/b/c/File1.go:100\n line100", "dir/a/b/c/File1.go", 100}, {"\n//line a/b/c/File1.go:100\n line100", filepath.Join("dir", "a", "b", "c", "File1.go"), 100},
}
var winsegments = []segment{
{"\n//line c:\\dir\\File1.go:100\n line100", "c:\\dir\\File1.go", 100},
} }
// Verify that comments of the form "//line filename:line" are interpreted correctly. // Verify that comments of the form "//line filename:line" are interpreted correctly.
func TestLineComments(t *testing.T) { func TestLineComments(t *testing.T) {
if runtime.GOOS == "windows" {
segments = append(segments, winsegments...)
}
// make source // make source
var src string var src string
for _, e := range segments { for _, e := range segments {
...@@ -477,7 +488,7 @@ func TestLineComments(t *testing.T) { ...@@ -477,7 +488,7 @@ func TestLineComments(t *testing.T) {
// verify scan // verify scan
var S Scanner var S Scanner
file := fset.AddFile("dir/TestLineComments", fset.Base(), len(src)) file := fset.AddFile(filepath.Join("dir", "TestLineComments"), fset.Base(), len(src))
S.Init(file, []byte(src), nil, 0) S.Init(file, []byte(src), nil, 0)
for _, s := range segments { for _, s := range segments {
p, _, lit := S.Scan() p, _, lit := S.Scan()
......
...@@ -2,15 +2,15 @@ ...@@ -2,15 +2,15 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This file implements scope support functions. // DEPRECATED FILE - WILL GO AWAY EVENTUALLY.
//
// Scope handling is now done in go/parser.
// The functionality here is only present to
// keep the typechecker running for now.
package typechecker package typechecker
import ( import "go/ast"
"fmt"
"go/ast"
"go/token"
)
func (tc *typechecker) openScope() *ast.Scope { func (tc *typechecker) openScope() *ast.Scope {
...@@ -24,52 +24,25 @@ func (tc *typechecker) closeScope() { ...@@ -24,52 +24,25 @@ func (tc *typechecker) closeScope() {
} }
// objPos computes the source position of the declaration of an object name.
// Only required for error reporting, so doesn't have to be fast.
func objPos(obj *ast.Object) (pos token.Pos) {
switch d := obj.Decl.(type) {
case *ast.Field:
for _, n := range d.Names {
if n.Name == obj.Name {
return n.Pos()
}
}
case *ast.ValueSpec:
for _, n := range d.Names {
if n.Name == obj.Name {
return n.Pos()
}
}
case *ast.TypeSpec:
return d.Name.Pos()
case *ast.FuncDecl:
return d.Name.Pos()
}
if debug {
fmt.Printf("decl = %T\n", obj.Decl)
}
panic("unreachable")
}
// declInScope declares an object of a given kind and name in scope and sets the object's Decl and N fields. // declInScope declares an object of a given kind and name in scope and sets the object's Decl and N fields.
// It returns the newly allocated object. If an object with the same name already exists in scope, an error // It returns the newly allocated object. If an object with the same name already exists in scope, an error
// is reported and the object is not inserted. // is reported and the object is not inserted.
// (Objects with _ name are always inserted into a scope without errors, but they cannot be found.) func (tc *typechecker) declInScope(scope *ast.Scope, kind ast.ObjKind, name *ast.Ident, decl interface{}, n int) *ast.Object {
func (tc *typechecker) declInScope(scope *ast.Scope, kind ast.Kind, name *ast.Ident, decl interface{}, n int) *ast.Object {
obj := ast.NewObj(kind, name.Name) obj := ast.NewObj(kind, name.Name)
obj.Decl = decl obj.Decl = decl
obj.N = n //obj.N = n
name.Obj = obj name.Obj = obj
if alt := scope.Insert(obj); alt != obj { if name.Name != "_" {
tc.Errorf(name.Pos(), "%s already declared at %s", name.Name, objPos(alt)) if alt := scope.Insert(obj); alt != obj {
tc.Errorf(name.Pos(), "%s already declared at %s", name.Name, tc.fset.Position(alt.Pos()).String())
}
} }
return obj return obj
} }
// decl is the same as declInScope(tc.topScope, ...) // decl is the same as declInScope(tc.topScope, ...)
func (tc *typechecker) decl(kind ast.Kind, name *ast.Ident, decl interface{}, n int) *ast.Object { func (tc *typechecker) decl(kind ast.ObjKind, name *ast.Ident, decl interface{}, n int) *ast.Object {
return tc.declInScope(tc.topScope, kind, name, decl, n) return tc.declInScope(tc.topScope, kind, name, decl, n)
} }
...@@ -91,7 +64,7 @@ func (tc *typechecker) find(name *ast.Ident) (obj *ast.Object) { ...@@ -91,7 +64,7 @@ func (tc *typechecker) find(name *ast.Ident) (obj *ast.Object) {
// findField returns the object with the given name if visible in the type's scope. // findField returns the object with the given name if visible in the type's scope.
// If no such object is found, an error is reported and a bad object is returned instead. // If no such object is found, an error is reported and a bad object is returned instead.
func (tc *typechecker) findField(typ *ast.Type, name *ast.Ident) (obj *ast.Object) { func (tc *typechecker) findField(typ *Type, name *ast.Ident) (obj *ast.Object) {
// TODO(gri) This is simplistic at the moment and ignores anonymous fields. // TODO(gri) This is simplistic at the moment and ignores anonymous fields.
obj = typ.Scope.Lookup(name.Name) obj = typ.Scope.Lookup(name.Name)
if obj == nil { if obj == nil {
...@@ -100,20 +73,3 @@ func (tc *typechecker) findField(typ *ast.Type, name *ast.Ident) (obj *ast.Objec ...@@ -100,20 +73,3 @@ func (tc *typechecker) findField(typ *ast.Type, name *ast.Ident) (obj *ast.Objec
} }
return return
} }
// printScope prints the objects in a scope.
func printScope(scope *ast.Scope) {
fmt.Printf("scope %p {", scope)
if scope != nil && len(scope.Objects) > 0 {
fmt.Println()
for _, obj := range scope.Objects {
form := "void"
if obj.Type != nil {
form = obj.Type.Form.String()
}
fmt.Printf("\t%s\t%s\n", obj.Name, form)
}
}
fmt.Printf("}\n")
}
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
package P1 package P1
const ( const (
c1 /* ERROR "missing initializer" */ c1 = 0
c2 int = 0 c2 int = 0
c3, c4 = 0 c3, c4 = 0
) )
...@@ -27,8 +27,11 @@ func (T) m1 /* ERROR "already declared" */ () {} ...@@ -27,8 +27,11 @@ func (T) m1 /* ERROR "already declared" */ () {}
func (x *T) m2(u, x /* ERROR "already declared" */ int) {} func (x *T) m2(u, x /* ERROR "already declared" */ int) {}
func (x *T) m3(a, b, c int) (u, x /* ERROR "already declared" */ int) {} func (x *T) m3(a, b, c int) (u, x /* ERROR "already declared" */ int) {}
func (T) _(x, x /* ERROR "already declared" */ int) {} // The following are disabled for now because the typechecker
func (T) _() (x, x /* ERROR "already declared" */ int) {} // is in the process of being rewritten and cannot handle them
// at the moment
//func (T) _(x, x /* "already declared" */ int) {}
//func (T) _() (x, x /* "already declared" */ int) {}
//func (PT) _() {} //func (PT) _() {}
......
...@@ -7,5 +7,5 @@ ...@@ -7,5 +7,5 @@
package P4 package P4
const ( const (
c0 /* ERROR "missing initializer" */ c0 = 0
) )
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package typechecker
import "go/ast"
// A Type represents a Go type.
type Type struct {
Form Form
Obj *ast.Object // corresponding type name, or nil
Scope *ast.Scope // fields and methods, always present
N uint // basic type id, array length, number of function results, or channel direction
Key, Elt *Type // map key and array, pointer, slice, map or channel element
Params *ast.Scope // function (receiver, input and result) parameters, tuple expressions (results of function calls), or nil
Expr ast.Expr // corresponding AST expression
}
// NewType creates a new type of a given form.
func NewType(form Form) *Type {
return &Type{Form: form, Scope: ast.NewScope(nil)}
}
// Form describes the form of a type.
type Form int
// The list of possible type forms.
const (
BadType Form = iota // for error handling
Unresolved // type not fully setup
Basic
Array
Struct
Pointer
Function
Method
Interface
Slice
Map
Channel
Tuple
)
var formStrings = [...]string{
BadType: "badType",
Unresolved: "unresolved",
Basic: "basic",
Array: "array",
Struct: "struct",
Pointer: "pointer",
Function: "function",
Method: "method",
Interface: "interface",
Slice: "slice",
Map: "map",
Channel: "channel",
Tuple: "tuple",
}
func (form Form) String() string { return formStrings[form] }
// The list of basic type id's.
const (
Bool = iota
Byte
Uint
Int
Float
Complex
Uintptr
String
Uint8
Uint16
Uint32
Uint64
Int8
Int16
Int32
Int64
Float32
Float64
Complex64
Complex128
// TODO(gri) ideal types are missing
)
var BasicTypes = map[uint]string{
Bool: "bool",
Byte: "byte",
Uint: "uint",
Int: "int",
Float: "float",
Complex: "complex",
Uintptr: "uintptr",
String: "string",
Uint8: "uint8",
Uint16: "uint16",
Uint32: "uint32",
Uint64: "uint64",
Int8: "int8",
Int16: "int16",
Int32: "int32",
Int64: "int64",
Float32: "float32",
Float64: "float64",
Complex64: "complex64",
Complex128: "complex128",
}
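
A minimal sketch, assuming the package typechecker context of the new file above; composePointerToInt is an illustrative name, not part of the patch. It shows how NewType, the Form constants, and the basic type ids compose, with Elt used as the Type field comments describe:

// composePointerToInt builds the representation of *int using only the
// declarations above (NewType, the Form constants, and the Int basic id).
func composePointerToInt() *Type {
	elem := NewType(Basic)
	elem.N = Int // basic type id from the const block above
	ptr := NewType(Pointer)
	ptr.Elt = elem // Elt holds the pointer's element type
	return ptr
}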
...@@ -65,6 +65,7 @@ type typechecker struct { ...@@ -65,6 +65,7 @@ type typechecker struct {
fset *token.FileSet fset *token.FileSet
scanner.ErrorVector scanner.ErrorVector
importer Importer importer Importer
globals []*ast.Object // list of global objects
topScope *ast.Scope // current top-most scope topScope *ast.Scope // current top-most scope
cyclemap map[*ast.Object]bool // for cycle detection cyclemap map[*ast.Object]bool // for cycle detection
iota int // current value of iota iota int // current value of iota
...@@ -94,7 +95,7 @@ phase 1: declare all global objects; also collect all function and method declar ...@@ -94,7 +95,7 @@ phase 1: declare all global objects; also collect all function and method declar
- report global double declarations - report global double declarations
phase 2: bind methods to their receiver base types phase 2: bind methods to their receiver base types
- received base types must be declared in the package, thus for - receiver base types must be declared in the package, thus for
each method a corresponding (unresolved) type must exist each method a corresponding (unresolved) type must exist
- report method double declarations and errors with base types - report method double declarations and errors with base types
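
A tiny input sketch for the phases described above (illustrative source, assumed file p.go, not part of the patch): phase 1 declares T with an unresolved type and collects m, and phase 2 can bind m only because T's type object already exists in the package scope:

package p

type T struct{ x int } // phase 1: T is declared with an Unresolved type

func (t *T) m() int { return t.x } // phase 2: m is bound into T's type scope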
...@@ -142,16 +143,16 @@ func (tc *typechecker) checkPackage(pkg *ast.Package) { ...@@ -142,16 +143,16 @@ func (tc *typechecker) checkPackage(pkg *ast.Package) {
} }
// phase 3: resolve all global objects // phase 3: resolve all global objects
// (note that objects with _ name are also in the scope)
tc.cyclemap = make(map[*ast.Object]bool) tc.cyclemap = make(map[*ast.Object]bool)
for _, obj := range tc.topScope.Objects { for _, obj := range tc.globals {
tc.resolve(obj) tc.resolve(obj)
} }
assert(len(tc.cyclemap) == 0) assert(len(tc.cyclemap) == 0)
// 4: sequentially typecheck function and method bodies // 4: sequentially typecheck function and method bodies
for _, f := range funcs { for _, f := range funcs {
tc.checkBlock(f.Body.List, f.Name.Obj.Type) ftype, _ := f.Name.Obj.Type.(*Type)
tc.checkBlock(f.Body.List, ftype)
} }
pkg.Scope = tc.topScope pkg.Scope = tc.topScope
...@@ -183,11 +184,11 @@ func (tc *typechecker) declGlobal(global ast.Decl) { ...@@ -183,11 +184,11 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
} }
} }
for _, name := range s.Names { for _, name := range s.Names {
tc.decl(ast.Con, name, s, iota) tc.globals = append(tc.globals, tc.decl(ast.Con, name, s, iota))
} }
case token.VAR: case token.VAR:
for _, name := range s.Names { for _, name := range s.Names {
tc.decl(ast.Var, name, s, 0) tc.globals = append(tc.globals, tc.decl(ast.Var, name, s, 0))
} }
default: default:
panic("unreachable") panic("unreachable")
...@@ -196,9 +197,10 @@ func (tc *typechecker) declGlobal(global ast.Decl) { ...@@ -196,9 +197,10 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
iota++ iota++
case *ast.TypeSpec: case *ast.TypeSpec:
obj := tc.decl(ast.Typ, s.Name, s, 0) obj := tc.decl(ast.Typ, s.Name, s, 0)
tc.globals = append(tc.globals, obj)
// give all type objects an unresolved type so // give all type objects an unresolved type so
// that we can collect methods in the type scope // that we can collect methods in the type scope
typ := ast.NewType(ast.Unresolved) typ := NewType(Unresolved)
obj.Type = typ obj.Type = typ
typ.Obj = obj typ.Obj = obj
default: default:
...@@ -208,7 +210,7 @@ func (tc *typechecker) declGlobal(global ast.Decl) { ...@@ -208,7 +210,7 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
case *ast.FuncDecl: case *ast.FuncDecl:
if d.Recv == nil { if d.Recv == nil {
tc.decl(ast.Fun, d.Name, d, 0) tc.globals = append(tc.globals, tc.decl(ast.Fun, d.Name, d, 0))
} }
default: default:
...@@ -239,8 +241,8 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) { ...@@ -239,8 +241,8 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) {
} else if obj.Kind != ast.Typ { } else if obj.Kind != ast.Typ {
tc.Errorf(name.Pos(), "invalid receiver: %s is not a type", name.Name) tc.Errorf(name.Pos(), "invalid receiver: %s is not a type", name.Name)
} else { } else {
typ := obj.Type typ := obj.Type.(*Type)
assert(typ.Form == ast.Unresolved) assert(typ.Form == Unresolved)
scope = typ.Scope scope = typ.Scope
} }
} }
...@@ -261,7 +263,7 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) { ...@@ -261,7 +263,7 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) {
func (tc *typechecker) resolve(obj *ast.Object) { func (tc *typechecker) resolve(obj *ast.Object) {
// check for declaration cycles // check for declaration cycles
if tc.cyclemap[obj] { if tc.cyclemap[obj] {
tc.Errorf(objPos(obj), "illegal cycle in declaration of %s", obj.Name) tc.Errorf(obj.Pos(), "illegal cycle in declaration of %s", obj.Name)
obj.Kind = ast.Bad obj.Kind = ast.Bad
return return
} }
...@@ -271,7 +273,7 @@ func (tc *typechecker) resolve(obj *ast.Object) { ...@@ -271,7 +273,7 @@ func (tc *typechecker) resolve(obj *ast.Object) {
}() }()
// resolve non-type objects // resolve non-type objects
typ := obj.Type typ, _ := obj.Type.(*Type)
if typ == nil { if typ == nil {
switch obj.Kind { switch obj.Kind {
case ast.Bad: case ast.Bad:
...@@ -282,12 +284,12 @@ func (tc *typechecker) resolve(obj *ast.Object) { ...@@ -282,12 +284,12 @@ func (tc *typechecker) resolve(obj *ast.Object) {
case ast.Var: case ast.Var:
tc.declVar(obj) tc.declVar(obj)
//obj.Type = tc.typeFor(nil, obj.Decl.(*ast.ValueSpec).Type, false) obj.Type = tc.typeFor(nil, obj.Decl.(*ast.ValueSpec).Type, false)
case ast.Fun: case ast.Fun:
obj.Type = ast.NewType(ast.Function) obj.Type = NewType(Function)
t := obj.Decl.(*ast.FuncDecl).Type t := obj.Decl.(*ast.FuncDecl).Type
tc.declSignature(obj.Type, nil, t.Params, t.Results) tc.declSignature(obj.Type.(*Type), nil, t.Params, t.Results)
default: default:
// type objects have non-nil types when resolve is called // type objects have non-nil types when resolve is called
...@@ -300,32 +302,34 @@ func (tc *typechecker) resolve(obj *ast.Object) { ...@@ -300,32 +302,34 @@ func (tc *typechecker) resolve(obj *ast.Object) {
} }
// resolve type objects // resolve type objects
if typ.Form == ast.Unresolved { if typ.Form == Unresolved {
tc.typeFor(typ, typ.Obj.Decl.(*ast.TypeSpec).Type, false) tc.typeFor(typ, typ.Obj.Decl.(*ast.TypeSpec).Type, false)
// provide types for all methods // provide types for all methods
for _, obj := range typ.Scope.Objects { for _, obj := range typ.Scope.Objects {
if obj.Kind == ast.Fun { if obj.Kind == ast.Fun {
assert(obj.Type == nil) assert(obj.Type == nil)
obj.Type = ast.NewType(ast.Method) obj.Type = NewType(Method)
f := obj.Decl.(*ast.FuncDecl) f := obj.Decl.(*ast.FuncDecl)
t := f.Type t := f.Type
tc.declSignature(obj.Type, f.Recv, t.Params, t.Results) tc.declSignature(obj.Type.(*Type), f.Recv, t.Params, t.Results)
} }
} }
} }
} }
func (tc *typechecker) checkBlock(body []ast.Stmt, ftype *ast.Type) { func (tc *typechecker) checkBlock(body []ast.Stmt, ftype *Type) {
tc.openScope() tc.openScope()
defer tc.closeScope() defer tc.closeScope()
// inject function/method parameters into block scope, if any // inject function/method parameters into block scope, if any
if ftype != nil { if ftype != nil {
for _, par := range ftype.Params.Objects { for _, par := range ftype.Params.Objects {
obj := tc.topScope.Insert(par) if par.Name != "_" {
assert(obj == par) // ftype has no double declarations obj := tc.topScope.Insert(par)
assert(obj == par) // ftype has no double declarations
}
} }
} }
...@@ -362,8 +366,8 @@ func (tc *typechecker) declFields(scope *ast.Scope, fields *ast.FieldList, ref b ...@@ -362,8 +366,8 @@ func (tc *typechecker) declFields(scope *ast.Scope, fields *ast.FieldList, ref b
} }
func (tc *typechecker) declSignature(typ *ast.Type, recv, params, results *ast.FieldList) { func (tc *typechecker) declSignature(typ *Type, recv, params, results *ast.FieldList) {
assert((typ.Form == ast.Method) == (recv != nil)) assert((typ.Form == Method) == (recv != nil))
typ.Params = ast.NewScope(nil) typ.Params = ast.NewScope(nil)
tc.declFields(typ.Params, recv, true) tc.declFields(typ.Params, recv, true)
tc.declFields(typ.Params, params, true) tc.declFields(typ.Params, params, true)
...@@ -371,7 +375,7 @@ func (tc *typechecker) declSignature(typ *ast.Type, recv, params, results *ast.F ...@@ -371,7 +375,7 @@ func (tc *typechecker) declSignature(typ *ast.Type, recv, params, results *ast.F
} }
func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Type) { func (tc *typechecker) typeFor(def *Type, x ast.Expr, ref bool) (typ *Type) {
x = unparen(x) x = unparen(x)
// type name // type name
...@@ -381,10 +385,10 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty ...@@ -381,10 +385,10 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
if obj.Kind != ast.Typ { if obj.Kind != ast.Typ {
tc.Errorf(t.Pos(), "%s is not a type", t.Name) tc.Errorf(t.Pos(), "%s is not a type", t.Name)
if def == nil { if def == nil {
typ = ast.NewType(ast.BadType) typ = NewType(BadType)
} else { } else {
typ = def typ = def
typ.Form = ast.BadType typ.Form = BadType
} }
typ.Expr = x typ.Expr = x
return return
...@@ -393,7 +397,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty ...@@ -393,7 +397,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
if !ref { if !ref {
tc.resolve(obj) // check for cycles even if type resolved tc.resolve(obj) // check for cycles even if type resolved
} }
typ = obj.Type typ = obj.Type.(*Type)
if def != nil { if def != nil {
// new type declaration: copy type structure // new type declaration: copy type structure
...@@ -410,7 +414,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty ...@@ -410,7 +414,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
// type literal // type literal
typ = def typ = def
if typ == nil { if typ == nil {
typ = ast.NewType(ast.BadType) typ = NewType(BadType)
} }
typ.Expr = x typ.Expr = x
...@@ -419,42 +423,42 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty ...@@ -419,42 +423,42 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
if debug { if debug {
fmt.Println("qualified identifier unimplemented") fmt.Println("qualified identifier unimplemented")
} }
typ.Form = ast.BadType typ.Form = BadType
case *ast.StarExpr: case *ast.StarExpr:
typ.Form = ast.Pointer typ.Form = Pointer
typ.Elt = tc.typeFor(nil, t.X, true) typ.Elt = tc.typeFor(nil, t.X, true)
case *ast.ArrayType: case *ast.ArrayType:
if t.Len != nil { if t.Len != nil {
typ.Form = ast.Array typ.Form = Array
// TODO(gri) compute the real length // TODO(gri) compute the real length
// (this may call resolve recursively) // (this may call resolve recursively)
(*typ).N = 42 (*typ).N = 42
} else { } else {
typ.Form = ast.Slice typ.Form = Slice
} }
typ.Elt = tc.typeFor(nil, t.Elt, t.Len == nil) typ.Elt = tc.typeFor(nil, t.Elt, t.Len == nil)
case *ast.StructType: case *ast.StructType:
typ.Form = ast.Struct typ.Form = Struct
tc.declFields(typ.Scope, t.Fields, false) tc.declFields(typ.Scope, t.Fields, false)
case *ast.FuncType: case *ast.FuncType:
typ.Form = ast.Function typ.Form = Function
tc.declSignature(typ, nil, t.Params, t.Results) tc.declSignature(typ, nil, t.Params, t.Results)
case *ast.InterfaceType: case *ast.InterfaceType:
typ.Form = ast.Interface typ.Form = Interface
tc.declFields(typ.Scope, t.Methods, true) tc.declFields(typ.Scope, t.Methods, true)
case *ast.MapType: case *ast.MapType:
typ.Form = ast.Map typ.Form = Map
typ.Key = tc.typeFor(nil, t.Key, true) typ.Key = tc.typeFor(nil, t.Key, true)
typ.Elt = tc.typeFor(nil, t.Value, true) typ.Elt = tc.typeFor(nil, t.Value, true)
case *ast.ChanType: case *ast.ChanType:
typ.Form = ast.Channel typ.Form = Channel
typ.N = uint(t.Dir) typ.N = uint(t.Dir)
typ.Elt = tc.typeFor(nil, t.Value, true) typ.Elt = tc.typeFor(nil, t.Value, true)
......
...@@ -93,7 +93,7 @@ func expectedErrors(t *testing.T, pkg *ast.Package) (list scanner.ErrorList) { ...@@ -93,7 +93,7 @@ func expectedErrors(t *testing.T, pkg *ast.Package) (list scanner.ErrorList) {
func testFilter(f *os.FileInfo) bool { func testFilter(f *os.FileInfo) bool {
return strings.HasSuffix(f.Name, ".go") && f.Name[0] != '.' return strings.HasSuffix(f.Name, ".src") && f.Name[0] != '.'
} }
......
...@@ -24,8 +24,8 @@ func init() { ...@@ -24,8 +24,8 @@ func init() {
Universe = ast.NewScope(nil) Universe = ast.NewScope(nil)
// basic types // basic types
for n, name := range ast.BasicTypes { for n, name := range BasicTypes {
typ := ast.NewType(ast.Basic) typ := NewType(Basic)
typ.N = n typ.N = n
obj := ast.NewObj(ast.Typ, name) obj := ast.NewObj(ast.Typ, name)
obj.Type = typ obj.Type = typ
......
...@@ -50,7 +50,7 @@ func testError(t *testing.T) { ...@@ -50,7 +50,7 @@ func testError(t *testing.T) {
func TestUintCodec(t *testing.T) { func TestUintCodec(t *testing.T) {
defer testError(t) defer testError(t)
b := new(bytes.Buffer) b := new(bytes.Buffer)
encState := newEncoderState(nil, b) encState := newEncoderState(b)
for _, tt := range encodeT { for _, tt := range encodeT {
b.Reset() b.Reset()
encState.encodeUint(tt.x) encState.encodeUint(tt.x)
...@@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) { ...@@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) {
t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes()) t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
} }
} }
decState := newDecodeState(nil, b) decState := newDecodeState(b)
for u := uint64(0); ; u = (u + 1) * 7 { for u := uint64(0); ; u = (u + 1) * 7 {
b.Reset() b.Reset()
encState.encodeUint(u) encState.encodeUint(u)
...@@ -75,9 +75,9 @@ func TestUintCodec(t *testing.T) { ...@@ -75,9 +75,9 @@ func TestUintCodec(t *testing.T) {
func verifyInt(i int64, t *testing.T) { func verifyInt(i int64, t *testing.T) {
defer testError(t) defer testError(t)
var b = new(bytes.Buffer) var b = new(bytes.Buffer)
encState := newEncoderState(nil, b) encState := newEncoderState(b)
encState.encodeInt(i) encState.encodeInt(i)
decState := newDecodeState(nil, b) decState := newDecodeState(b)
decState.buf = make([]byte, 8) decState.buf = make([]byte, 8)
j := decState.decodeInt() j := decState.decodeInt()
if i != j { if i != j {
...@@ -111,9 +111,16 @@ var complexResult = []byte{0x07, 0xFE, 0x31, 0x40, 0xFE, 0x33, 0x40} ...@@ -111,9 +111,16 @@ var complexResult = []byte{0x07, 0xFE, 0x31, 0x40, 0xFE, 0x33, 0x40}
// The result of encoding "hello" with field number 7 // The result of encoding "hello" with field number 7
var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'} var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'}
func newencoderState(b *bytes.Buffer) *encoderState { func newDecodeState(buf *bytes.Buffer) *decoderState {
d := new(decoderState)
d.b = buf
d.buf = make([]byte, uint64Size)
return d
}
func newEncoderState(b *bytes.Buffer) *encoderState {
b.Reset() b.Reset()
state := newEncoderState(nil, b) state := &encoderState{enc: nil, b: b}
state.fieldnum = -1 state.fieldnum = -1
return state return state
} }
...@@ -127,7 +134,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -127,7 +134,7 @@ func TestScalarEncInstructions(t *testing.T) {
{ {
data := struct{ a bool }{true} data := struct{ a bool }{true}
instr := &encInstr{encBool, 6, 0, 0} instr := &encInstr{encBool, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(boolResult, b.Bytes()) { if !bytes.Equal(boolResult, b.Bytes()) {
t.Errorf("bool enc instructions: expected % x got % x", boolResult, b.Bytes()) t.Errorf("bool enc instructions: expected % x got % x", boolResult, b.Bytes())
...@@ -139,7 +146,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -139,7 +146,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a int }{17} data := struct{ a int }{17}
instr := &encInstr{encInt, 6, 0, 0} instr := &encInstr{encInt, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) { if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int enc instructions: expected % x got % x", signedResult, b.Bytes()) t.Errorf("int enc instructions: expected % x got % x", signedResult, b.Bytes())
...@@ -151,7 +158,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -151,7 +158,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a uint }{17} data := struct{ a uint }{17}
instr := &encInstr{encUint, 6, 0, 0} instr := &encInstr{encUint, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) { if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint enc instructions: expected % x got % x", unsignedResult, b.Bytes()) t.Errorf("uint enc instructions: expected % x got % x", unsignedResult, b.Bytes())
...@@ -163,7 +170,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -163,7 +170,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a int8 }{17} data := struct{ a int8 }{17}
instr := &encInstr{encInt8, 6, 0, 0} instr := &encInstr{encInt8, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) { if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int8 enc instructions: expected % x got % x", signedResult, b.Bytes()) t.Errorf("int8 enc instructions: expected % x got % x", signedResult, b.Bytes())
...@@ -175,7 +182,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -175,7 +182,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a uint8 }{17} data := struct{ a uint8 }{17}
instr := &encInstr{encUint8, 6, 0, 0} instr := &encInstr{encUint8, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) { if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint8 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) t.Errorf("uint8 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
...@@ -187,7 +194,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -187,7 +194,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a int16 }{17} data := struct{ a int16 }{17}
instr := &encInstr{encInt16, 6, 0, 0} instr := &encInstr{encInt16, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) { if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int16 enc instructions: expected % x got % x", signedResult, b.Bytes()) t.Errorf("int16 enc instructions: expected % x got % x", signedResult, b.Bytes())
...@@ -199,7 +206,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -199,7 +206,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a uint16 }{17} data := struct{ a uint16 }{17}
instr := &encInstr{encUint16, 6, 0, 0} instr := &encInstr{encUint16, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) { if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint16 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) t.Errorf("uint16 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
...@@ -211,7 +218,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -211,7 +218,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a int32 }{17} data := struct{ a int32 }{17}
instr := &encInstr{encInt32, 6, 0, 0} instr := &encInstr{encInt32, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) { if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int32 enc instructions: expected % x got % x", signedResult, b.Bytes()) t.Errorf("int32 enc instructions: expected % x got % x", signedResult, b.Bytes())
...@@ -223,7 +230,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -223,7 +230,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a uint32 }{17} data := struct{ a uint32 }{17}
instr := &encInstr{encUint32, 6, 0, 0} instr := &encInstr{encUint32, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) { if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint32 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) t.Errorf("uint32 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
...@@ -235,7 +242,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -235,7 +242,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a int64 }{17} data := struct{ a int64 }{17}
instr := &encInstr{encInt64, 6, 0, 0} instr := &encInstr{encInt64, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) { if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int64 enc instructions: expected % x got % x", signedResult, b.Bytes()) t.Errorf("int64 enc instructions: expected % x got % x", signedResult, b.Bytes())
...@@ -247,7 +254,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -247,7 +254,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a uint64 }{17} data := struct{ a uint64 }{17}
instr := &encInstr{encUint64, 6, 0, 0} instr := &encInstr{encUint64, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) { if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint64 enc instructions: expected % x got % x", unsignedResult, b.Bytes()) t.Errorf("uint64 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
...@@ -259,7 +266,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -259,7 +266,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a float32 }{17} data := struct{ a float32 }{17}
instr := &encInstr{encFloat32, 6, 0, 0} instr := &encInstr{encFloat32, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(floatResult, b.Bytes()) { if !bytes.Equal(floatResult, b.Bytes()) {
t.Errorf("float32 enc instructions: expected % x got % x", floatResult, b.Bytes()) t.Errorf("float32 enc instructions: expected % x got % x", floatResult, b.Bytes())
...@@ -271,7 +278,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -271,7 +278,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a float64 }{17} data := struct{ a float64 }{17}
instr := &encInstr{encFloat64, 6, 0, 0} instr := &encInstr{encFloat64, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(floatResult, b.Bytes()) { if !bytes.Equal(floatResult, b.Bytes()) {
t.Errorf("float64 enc instructions: expected % x got % x", floatResult, b.Bytes()) t.Errorf("float64 enc instructions: expected % x got % x", floatResult, b.Bytes())
...@@ -283,7 +290,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -283,7 +290,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a []byte }{[]byte("hello")} data := struct{ a []byte }{[]byte("hello")}
instr := &encInstr{encUint8Array, 6, 0, 0} instr := &encInstr{encUint8Array, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(bytesResult, b.Bytes()) { if !bytes.Equal(bytesResult, b.Bytes()) {
t.Errorf("bytes enc instructions: expected % x got % x", bytesResult, b.Bytes()) t.Errorf("bytes enc instructions: expected % x got % x", bytesResult, b.Bytes())
...@@ -295,7 +302,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -295,7 +302,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset() b.Reset()
data := struct{ a string }{"hello"} data := struct{ a string }{"hello"}
instr := &encInstr{encString, 6, 0, 0} instr := &encInstr{encString, 6, 0, 0}
state := newencoderState(b) state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data)) instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(bytesResult, b.Bytes()) { if !bytes.Equal(bytesResult, b.Bytes()) {
t.Errorf("string enc instructions: expected % x got % x", bytesResult, b.Bytes()) t.Errorf("string enc instructions: expected % x got % x", bytesResult, b.Bytes())
...@@ -303,7 +310,7 @@ func TestScalarEncInstructions(t *testing.T) { ...@@ -303,7 +310,7 @@ func TestScalarEncInstructions(t *testing.T) {
} }
} }
func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) { func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, p unsafe.Pointer) {
defer testError(t) defer testError(t)
v := int(state.decodeUint()) v := int(state.decodeUint())
if v+state.fieldnum != 6 { if v+state.fieldnum != 6 {
...@@ -313,9 +320,9 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un ...@@ -313,9 +320,9 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
state.fieldnum = 6 state.fieldnum = 6
} }
func newDecodeStateFromData(data []byte) *decodeState { func newDecodeStateFromData(data []byte) *decoderState {
b := bytes.NewBuffer(data) b := bytes.NewBuffer(data)
state := newDecodeState(nil, b) state := newDecodeState(b)
state.fieldnum = -1 state.fieldnum = -1
return state return state
} }
...@@ -997,9 +1004,9 @@ func TestInvalidField(t *testing.T) { ...@@ -997,9 +1004,9 @@ func TestInvalidField(t *testing.T) {
var bad0 Bad0 var bad0 Bad0
bad0.CH = make(chan int) bad0.CH = make(chan int)
b := new(bytes.Buffer) b := new(bytes.Buffer)
var nilEncoder *Encoder dummyEncoder := new(Encoder) // sufficient for this purpose.
err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) dummyEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
if err == nil { if err := dummyEncoder.err; err == nil {
t.Error("expected error; got none") t.Error("expected error; got none")
} else if strings.Index(err.String(), "type") < 0 { } else if strings.Index(err.String(), "type") < 0 {
t.Error("expected type error; got", err) t.Error("expected type error; got", err)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package gob package gob
import ( import (
"bufio"
"bytes" "bytes"
"io" "io"
"os" "os"
...@@ -21,7 +22,7 @@ type Decoder struct { ...@@ -21,7 +22,7 @@ type Decoder struct {
wireType map[typeId]*wireType // map from remote ID to local description wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects ignorerCache map[typeId]**decEngine // ditto for ignored objects
countState *decodeState // reads counts from wire freeList *decoderState // list of free decoderStates; avoids reallocation
countBuf []byte // used for decoding integers while parsing messages countBuf []byte // used for decoding integers while parsing messages
tmp []byte // temporary storage for i/o; saves reallocating tmp []byte // temporary storage for i/o; saves reallocating
err os.Error err os.Error
...@@ -30,7 +31,7 @@ type Decoder struct { ...@@ -30,7 +31,7 @@ type Decoder struct {
// NewDecoder returns a new decoder that reads from the io.Reader. // NewDecoder returns a new decoder that reads from the io.Reader.
func NewDecoder(r io.Reader) *Decoder { func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder) dec := new(Decoder)
dec.r = r dec.r = bufio.NewReader(r)
dec.wireType = make(map[typeId]*wireType) dec.wireType = make(map[typeId]*wireType)
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine) dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
dec.ignorerCache = make(map[typeId]**decEngine) dec.ignorerCache = make(map[typeId]**decEngine)
...@@ -49,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) { ...@@ -49,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) {
// Type: // Type:
wire := new(wireType) wire := new(wireType)
dec.err = dec.decodeValue(tWireType, reflect.NewValue(wire)) dec.decodeValue(tWireType, reflect.NewValue(wire))
if dec.err != nil { if dec.err != nil {
return return
} }
...@@ -184,7 +185,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { ...@@ -184,7 +185,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
dec.err = nil dec.err = nil
id := dec.decodeTypeSequence(false) id := dec.decodeTypeSequence(false)
if dec.err == nil { if dec.err == nil {
dec.err = dec.decodeValue(id, value) dec.decodeValue(id, value)
} }
return dec.err return dec.err
} }
......
...@@ -19,7 +19,9 @@ type Encoder struct { ...@@ -19,7 +19,9 @@ type Encoder struct {
w []io.Writer // where to send the data w []io.Writer // where to send the data
sent map[reflect.Type]typeId // which types we've already sent sent map[reflect.Type]typeId // which types we've already sent
countState *encoderState // stage for writing counts countState *encoderState // stage for writing counts
freeList *encoderState // list of free encoderStates; avoids reallocation
buf []byte // for collecting the output. buf []byte // for collecting the output.
byteBuf bytes.Buffer // buffer for top-level encoderState
err os.Error err os.Error
} }
...@@ -28,7 +30,7 @@ func NewEncoder(w io.Writer) *Encoder { ...@@ -28,7 +30,7 @@ func NewEncoder(w io.Writer) *Encoder {
enc := new(Encoder) enc := new(Encoder)
enc.w = []io.Writer{w} enc.w = []io.Writer{w}
enc.sent = make(map[reflect.Type]typeId) enc.sent = make(map[reflect.Type]typeId)
enc.countState = newEncoderState(enc, new(bytes.Buffer)) enc.countState = enc.newEncoderState(new(bytes.Buffer))
return enc return enc
} }
...@@ -78,12 +80,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { ...@@ -78,12 +80,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
} }
} }
// sendActualType sends the requested type, without further investigation, unless
// it's been sent before.
func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
if _, alreadySent := enc.sent[actual]; alreadySent {
return false
}
typeLock.Lock()
info, err := getTypeInfo(ut)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type, both what the user gave us and the base type.
enc.sent[ut.base] = info.id
if ut.user != ut.base {
enc.sent[ut.user] = info.id
}
// Now send the inner types
switch st := actual.(type) {
case *reflect.StructType:
for i := 0; i < st.NumField(); i++ {
enc.sendType(w, state, st.Field(i).Type)
}
case reflect.ArrayOrSliceType:
enc.sendType(w, state, st.Elem())
}
return true
}
// sendType sends the type info to the other side, if necessary.
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
// Drill down to the base type.
ut := userType(origt) ut := userType(origt)
rt := ut.base if ut.isGobEncoder {
// The rules are different: regardless of the underlying type's representation,
// we need to tell the other side that this exact type is a GobEncoder.
return enc.sendActualType(w, state, ut, ut.user)
}
switch rt := rt.(type) { // It's a concrete value, so drill down to the base type.
switch rt := ut.base.(type) {
default: default:
// Basic types and interfaces do not need to be described. // Basic types and interfaces do not need to be described.
return return
...@@ -109,43 +156,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ ...@@ -109,43 +156,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ
return return
} }
// Have we already sent this type? This time we ask about the base type. return enc.sendActualType(w, state, ut, ut.base)
if _, alreadySent := enc.sent[rt]; alreadySent {
return
}
// Need to send it.
typeLock.Lock()
info, err := getTypeInfo(rt)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type.
enc.sent[rt] = info.id
// Remember we've sent the top-level, possibly indirect type too.
enc.sent[origt] = info.id
// Now send the inner types
switch st := rt.(type) {
case *reflect.StructType:
for i := 0; i < st.NumField(); i++ {
enc.sendType(w, state, st.Field(i).Type)
}
case reflect.ArrayOrSliceType:
enc.sendType(w, state, st.Elem())
}
return true
} }
// Encode transmits the data item represented by the empty interface value, // Encode transmits the data item represented by the empty interface value,
...@@ -159,11 +170,14 @@ func (enc *Encoder) Encode(e interface{}) os.Error { ...@@ -159,11 +170,14 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// sent. // sent.
func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
// Make sure the type is known to the other side. // Make sure the type is known to the other side.
// First, have we already sent this (base) type? // First, have we already sent this type?
base := ut.base rt := ut.base
if _, alreadySent := enc.sent[base]; !alreadySent { if ut.isGobEncoder {
rt = ut.user
}
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it. // No, so send it.
sent := enc.sendType(w, state, base) sent := enc.sendType(w, state, rt)
if enc.err != nil { if enc.err != nil {
return return
} }
...@@ -172,13 +186,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use ...@@ -172,13 +186,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use
// need to send the type info but we do need to update enc.sent. // need to send the type info but we do need to update enc.sent.
if !sent { if !sent {
typeLock.Lock() typeLock.Lock()
info, err := getTypeInfo(base) info, err := getTypeInfo(ut)
typeLock.Unlock() typeLock.Unlock()
if err != nil { if err != nil {
enc.setError(err) enc.setError(err)
return return
} }
enc.sent[base] = info.id enc.sent[rt] = info.id
} }
} }
} }
...@@ -206,7 +220,8 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { ...@@ -206,7 +220,8 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
} }
enc.err = nil enc.err = nil
state := newEncoderState(enc, new(bytes.Buffer)) enc.byteBuf.Reset()
state := enc.newEncoderState(&enc.byteBuf)
enc.sendTypeDescriptor(enc.writer(), state, ut) enc.sendTypeDescriptor(enc.writer(), state, ut)
enc.sendTypeId(state, ut) enc.sendTypeId(state, ut)
...@@ -215,12 +230,11 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { ...@@ -215,12 +230,11 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
} }
// Encode the object. // Encode the object.
err = enc.encode(state.b, value, ut) enc.encode(state.b, value, ut)
if err != nil { if enc.err == nil {
enc.setError(err)
} else {
enc.writeMessage(enc.writer(), state.b) enc.writeMessage(enc.writer(), state.b)
} }
enc.freeEncoderState(state)
return enc.err return enc.err
} }
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains tests of the GobEncoder/GobDecoder support.
package gob
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
)
// Types that implement the GobEncoder/Decoder interfaces.
type ByteStruct struct {
a byte // not an exported field
}
type StringStruct struct {
s string // not an exported field
}
type Gobber int
type ValueGobber string // encodes with a value, decodes with a pointer.
// The relevant methods
func (g *ByteStruct) GobEncode() ([]byte, os.Error) {
b := make([]byte, 3)
b[0] = g.a
b[1] = g.a + 1
b[2] = g.a + 2
return b, nil
}
func (g *ByteStruct) GobDecode(data []byte) os.Error {
if g == nil {
return os.ErrorString("NIL RECEIVER")
}
// Expect N sequential-valued bytes.
if len(data) == 0 {
return os.EOF
}
g.a = data[0]
for i, c := range data {
if c != g.a+byte(i) {
return os.ErrorString("invalid data sequence")
}
}
return nil
}
func (g *StringStruct) GobEncode() ([]byte, os.Error) {
return []byte(g.s), nil
}
func (g *StringStruct) GobDecode(data []byte) os.Error {
// Expect N sequential-valued bytes.
if len(data) == 0 {
return os.EOF
}
a := data[0]
for i, c := range data {
if c != a+byte(i) {
return os.ErrorString("invalid data sequence")
}
}
g.s = string(data)
return nil
}
func (g *Gobber) GobEncode() ([]byte, os.Error) {
return []byte(fmt.Sprintf("VALUE=%d", *g)), nil
}
func (g *Gobber) GobDecode(data []byte) os.Error {
_, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g))
return err
}
func (v ValueGobber) GobEncode() ([]byte, os.Error) {
return []byte(fmt.Sprintf("VALUE=%s", v)), nil
}
func (v *ValueGobber) GobDecode(data []byte) os.Error {
_, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v))
return err
}
// Structs that include GobEncodable fields.
type GobTest0 struct {
X int // guarantee we have something in common with GobTest*
G *ByteStruct
}
type GobTest1 struct {
X int // guarantee we have something in common with GobTest*
G *StringStruct
}
type GobTest2 struct {
X int // guarantee we have something in common with GobTest*
G string // not a GobEncoder - should give us errors
}
type GobTest3 struct {
X int // guarantee we have something in common with GobTest*
G *Gobber
}
type GobTest4 struct {
X int // guarantee we have something in common with GobTest*
V ValueGobber
}
type GobTest5 struct {
X int // guarantee we have something in common with GobTest*
V *ValueGobber
}
type GobTestIgnoreEncoder struct {
X int // guarantee we have something in common with GobTest*
}
type GobTestValueEncDec struct {
X int // guarantee we have something in common with GobTest*
G StringStruct // not a pointer.
}
type GobTestIndirectEncDec struct {
X int // guarantee we have something in common with GobTest*
G ***StringStruct // indirections to the receiver.
}
func TestGobEncoderField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest0)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.a != 'A' {
t.Errorf("expected 'A' got %c", x.G.a)
}
// Now a field that's not a structure.
b.Reset()
gobber := Gobber(23)
err = enc.Encode(GobTest3{17, &gobber})
if err != nil {
t.Fatal("encode error:", err)
}
y := new(GobTest3)
err = dec.Decode(y)
if err != nil {
t.Fatal("decode error:", err)
}
if *y.G != 23 {
t.Errorf("expected '23 got %d", *y.G)
}
}
// Even though the field is a value, we can still take its address
// and should be able to call the methods.
func TestGobEncoderValueField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTestValueEncDec{17, StringStruct{"HIJKL"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestValueEncDec)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.s != "HIJKL" {
t.Errorf("expected `HIJKL` got %s", x.G.s)
}
}
// GobEncode/Decode should work even if the value is
// more indirect than the receiver.
func TestGobEncoderIndirectField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
s := &StringStruct{"HIJKL"}
sp := &s
err := enc.Encode(GobTestIndirectEncDec{17, &sp})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIndirectEncDec)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if (***x.G).s != "HIJKL" {
t.Errorf("expected `HIJKL` got %s", (***x.G).s)
}
}
// As long as the fields have the same name and implement the
// interface, we can cross-connect them. Not sure it's useful
// and may even be bad but it works and it's hard to prevent
// without exposing the contents of the object, which would
// defeat the purpose.
func TestGobEncoderFieldsOfDifferentType(t *testing.T) {
// first, string in field to byte in field
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest0)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.a != 'A' {
t.Errorf("expected 'A' got %c", x.G.a)
}
// now the other direction, byte in field to string in field
b.Reset()
err = enc.Encode(GobTest0{17, &ByteStruct{'X'}})
if err != nil {
t.Fatal("encode error:", err)
}
y := new(GobTest1)
err = dec.Decode(y)
if err != nil {
t.Fatal("decode error:", err)
}
if y.G.s != "XYZ" {
t.Fatalf("expected `XYZ` got %c", y.G.s)
}
}
// Test that we can encode a value and decode into a pointer.
func TestGobEncoderValueEncoder(t *testing.T) {
// first, string in field to byte in field
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest4{17, ValueGobber("hello")})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest5)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if *x.V != "hello" {
t.Errorf("expected `hello` got %s", x.V)
}
}
func TestGobEncoderFieldTypeError(t *testing.T) {
// GobEncoder to non-decoder: error
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := &GobTest2{}
err = dec.Decode(x)
if err == nil {
t.Fatal("expected decode error for mismatched fields (encoder to non-decoder)")
}
if strings.Index(err.String(), "type") < 0 {
t.Fatal("expected type error; got", err)
}
// Non-encoder to GobDecoder: error
b.Reset()
err = enc.Encode(GobTest2{17, "ABC"})
if err != nil {
t.Fatal("encode error:", err)
}
y := &GobTest1{}
err = dec.Decode(y)
if err == nil {
t.Fatal("expected decode error for mistmatched fields (non-encoder to decoder)")
}
if strings.Index(err.String(), "type") < 0 {
t.Fatal("expected type error; got", err)
}
}
// Even though ByteStruct is a struct, it's treated as a singleton at the top level.
func TestGobEncoderStructSingleton(t *testing.T) {
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(&ByteStruct{'A'})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(ByteStruct)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.a != 'A' {
t.Errorf("expected 'A' got %c", x.a)
}
}
func TestGobEncoderNonStructSingleton(t *testing.T) {
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(Gobber(1234))
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
var x Gobber
err = dec.Decode(&x)
if err != nil {
t.Fatal("decode error:", err)
}
if x != 1234 {
t.Errorf("expected 1234 got %c", x)
}
}
func TestGobEncoderIgnoreStructField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIgnoreEncoder)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.X != 17 {
t.Errorf("expected 17 got %c", x.X)
}
}
func TestGobEncoderIgnoreNonStructField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
gobber := Gobber(23)
err := enc.Encode(GobTest3{17, &gobber})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIgnoreEncoder)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.X != 17 {
t.Errorf("expected 17 got %c", x.X)
}
}
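
A short usage sketch, assuming only the exported gob API and the Gobber type defined above; roundTripGobber is an illustrative helper, not part of the patch:

func roundTripGobber() (Gobber, os.Error) {
	var buf bytes.Buffer
	g := Gobber(42)
	// The Encoder calls g's GobEncode to obtain the transmitted bytes.
	if err := NewEncoder(&buf).Encode(&g); err != nil {
		return 0, err
	}
	// The Decoder hands those bytes to GobDecode on the target value.
	var out Gobber
	err := NewDecoder(&buf).Decode(&out)
	return out, err
}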
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"bytes"
"fmt"
"io"
"os"
"runtime"
"testing"
)
type Bench struct {
A int
B float64
C string
D []byte
}
func benchmarkEndToEnd(r io.Reader, w io.Writer, b *testing.B) {
b.StopTimer()
enc := NewEncoder(w)
dec := NewDecoder(r)
bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
b.StartTimer()
for i := 0; i < b.N; i++ {
if enc.Encode(bench) != nil {
panic("encode error")
}
if dec.Decode(bench) != nil {
panic("decode error")
}
}
}
func BenchmarkEndToEndPipe(b *testing.B) {
r, w, err := os.Pipe()
if err != nil {
panic("can't get pipe:" + err.String())
}
benchmarkEndToEnd(r, w, b)
}
func BenchmarkEndToEndByteBuffer(b *testing.B) {
var buf bytes.Buffer
benchmarkEndToEnd(&buf, &buf, b)
}
func TestCountEncodeMallocs(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
mallocs := 0 - runtime.MemStats.Mallocs
const count = 1000
for i := 0; i < count; i++ {
err := enc.Encode(bench)
if err != nil {
t.Fatal("encode:", err)
}
}
mallocs += runtime.MemStats.Mallocs
fmt.Printf("mallocs per encode of type Bench: %d\n", mallocs/count)
}
func TestCountDecodeMallocs(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
const count = 1000
for i := 0; i < count; i++ {
err := enc.Encode(bench)
if err != nil {
t.Fatal("encode:", err)
}
}
dec := NewDecoder(&buf)
mallocs := 0 - runtime.MemStats.Mallocs
for i := 0; i < count; i++ {
*bench = Bench{}
err := dec.Decode(&bench)
if err != nil {
t.Fatal("decode:", err)
}
}
mallocs += runtime.MemStats.Mallocs
fmt.Printf("mallocs per decode of type Bench: %d\n", mallocs/count)
}
...@@ -26,7 +26,7 @@ var basicTypes = []typeT{ ...@@ -26,7 +26,7 @@ var basicTypes = []typeT{
func getTypeUnlocked(name string, rt reflect.Type) gobType { func getTypeUnlocked(name string, rt reflect.Type) gobType {
typeLock.Lock() typeLock.Lock()
defer typeLock.Unlock() defer typeLock.Unlock()
t, err := getType(name, rt) t, err := getBaseType(name, rt)
if err != nil { if err != nil {
panic("getTypeUnlocked: " + err.String()) panic("getTypeUnlocked: " + err.String())
} }
...@@ -126,27 +126,27 @@ func TestMapType(t *testing.T) { ...@@ -126,27 +126,27 @@ func TestMapType(t *testing.T) {
} }
type Bar struct { type Bar struct {
x string X string
} }
// This structure has pointers and refers to itself, making it a good test case. // This structure has pointers and refers to itself, making it a good test case.
type Foo struct { type Foo struct {
a int A int
b int32 // will become int B int32 // will become int
c string C string
d []byte D []byte
e *float64 // will become float64 E *float64 // will become float64
f ****float64 // will become float64 F ****float64 // will become float64
g *Bar G *Bar
h *Bar // should not interpolate the definition of Bar again H *Bar // should not interpolate the definition of Bar again
i *Foo // will not explode I *Foo // will not explode
} }
func TestStructType(t *testing.T) { func TestStructType(t *testing.T) {
sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{})) sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{}))
str := sstruct.string() str := sstruct.string()
// If we can print it correctly, we built it correctly. // If we can print it correctly, we built it correctly.
expected := "Foo = struct { a int; b int; c string; d bytes; e float; f float; g Bar = struct { x string; }; h Bar; i Foo; }" expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }"
if str != expected { if str != expected {
t.Errorf("struct printed as %q; expected %q", str, expected) t.Errorf("struct printed as %q; expected %q", str, expected)
} }
......