Commit 8039ca76 by Ian Lance Taylor

Update to current version of Go library.

From-SVN: r171427
parent 7114321e
94d654be2064
31d7feb9281b
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
......@@ -19,6 +19,7 @@ import (
"hash/crc32"
"encoding/binary"
"io"
"io/ioutil"
"os"
)
......@@ -109,7 +110,7 @@ func (f *File) Open() (rc io.ReadCloser, err os.Error) {
r := io.NewSectionReader(f.zipr, off+f.bodyOffset, size)
switch f.Method {
case 0: // store (no compression)
rc = nopCloser{r}
rc = ioutil.NopCloser(r)
case 8: // DEFLATE
rc = flate.NewReader(r)
default:
......@@ -147,12 +148,6 @@ func (r *checksumReader) Read(b []byte) (n int, err os.Error) {
func (r *checksumReader) Close() os.Error { return r.rc.Close() }
type nopCloser struct {
io.Reader
}
func (f nopCloser) Close() os.Error { return nil }
func readFileHeader(f *File, r io.Reader) (err os.Error) {
defer func() {
if rerr, ok := recover().(os.Error); ok {
......
......@@ -8,6 +8,7 @@ package big
import (
"fmt"
"os"
"rand"
)
......@@ -393,62 +394,19 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
}
// SetBytes interprets b as the bytes of a big-endian, unsigned integer and
// sets z to that value.
func (z *Int) SetBytes(b []byte) *Int {
const s = _S
z.abs = z.abs.make((len(b) + s - 1) / s)
j := 0
for len(b) >= s {
var w Word
for i := s; i > 0; i-- {
w <<= 8
w |= Word(b[len(b)-i])
}
z.abs[j] = w
j++
b = b[0 : len(b)-s]
}
if len(b) > 0 {
var w Word
for i := len(b); i > 0; i-- {
w <<= 8
w |= Word(b[len(b)-i])
}
z.abs[j] = w
}
z.abs = z.abs.norm()
// SetBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z *Int) SetBytes(buf []byte) *Int {
z.abs = z.abs.setBytes(buf)
z.neg = false
return z
}
// Bytes returns the absolute value of x as a big-endian byte array.
// Bytes returns the absolute value of z as a big-endian byte slice.
func (z *Int) Bytes() []byte {
const s = _S
b := make([]byte, len(z.abs)*s)
for i, w := range z.abs {
wordBytes := b[(len(z.abs)-i-1)*s : (len(z.abs)-i)*s]
for j := s - 1; j >= 0; j-- {
wordBytes[j] = byte(w)
w >>= 8
}
}
i := 0
for i < len(b) && b[i] == 0 {
i++
}
return b[i:]
buf := make([]byte, len(z.abs)*_S)
return buf[z.abs.bytes(buf):]
}
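As a point of reference, a minimal sketch of the round trip through the rewritten pair (pre-Go 1 import path "big"; the values are illustrative, not part of the patch):
package main
import (
	"big"
	"fmt"
)
func main() {
	// 0x01ff00 == 130816; the trailing zero byte is significant and is kept.
	z := new(big.Int).SetBytes([]byte{0x01, 0xff, 0x00})
	fmt.Println(z)                 // 130816
	fmt.Printf("% x\n", z.Bytes()) // 01 ff 00 (leading zero bytes, if any, are trimmed)
}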
......@@ -739,3 +697,34 @@ func (z *Int) Not(x *Int) *Int {
z.neg = true // z cannot be zero if x is positive
return z
}
// Gob codec version. Permits backward-compatible changes to the encoding.
const version byte = 1
// GobEncode implements the gob.GobEncoder interface.
func (z *Int) GobEncode() ([]byte, os.Error) {
buf := make([]byte, len(z.abs)*_S+1) // extra byte for version and sign bit
i := z.abs.bytes(buf) - 1 // i >= 0
b := version << 1 // make space for sign bit
if z.neg {
b |= 1
}
buf[i] = b
return buf[i:], nil
}
// GobDecode implements the gob.GobDecoder interface.
func (z *Int) GobDecode(buf []byte) os.Error {
if len(buf) == 0 {
return os.NewError("Int.GobDecode: no data")
}
b := buf[0]
if b>>1 != version {
return os.NewError(fmt.Sprintf("Int.GobDecode: encoding version %d not supported", b>>1))
}
z.neg = b&1 != 0
z.abs = z.abs.setBytes(buf[1:])
return nil
}
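A worked example of the encoding above (illustrative, not from the patch): for z = -5, GobEncode returns []byte{0x03, 0x05} on any word size, since 0x03 is version<<1 with the sign bit set and 0x05 is |z| in big-endian form; GobDecode reverses this. A minimal sketch:
x := big.NewInt(-5)
p, _ := x.GobEncode() // p == []byte{0x03, 0x05}
var y big.Int
y.GobDecode(p) // afterwards y.Cmp(x) == 0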
......@@ -8,6 +8,7 @@ import (
"bytes"
"encoding/hex"
"fmt"
"gob"
"testing"
"testing/quick"
)
......@@ -1053,3 +1054,41 @@ func TestModInverse(t *testing.T) {
}
}
}
var gobEncodingTests = []string{
"0",
"1",
"2",
"10",
"42",
"1234567890",
"298472983472983471903246121093472394872319615612417471234712061",
}
func TestGobEncoding(t *testing.T) {
var medium bytes.Buffer
enc := gob.NewEncoder(&medium)
dec := gob.NewDecoder(&medium)
for i, test := range gobEncodingTests {
for j := 0; j < 2; j++ {
medium.Reset() // empty buffer for each test case (in case of failures)
stest := test
if j == 0 {
stest = "-" + test
}
var tx Int
tx.SetString(stest, 10)
if err := enc.Encode(&tx); err != nil {
t.Errorf("#%d%c: encoding failed: %s", i, 'a'+j, err)
}
var rx Int
if err := dec.Decode(&rx); err != nil {
t.Errorf("#%d%c: decoding failed: %s", i, 'a'+j, err)
}
if rx.Cmp(&tx) != 0 {
t.Errorf("#%d%c: transmission failed: got %s want %s", i, 'a'+j, &rx, &tx)
}
}
}
}
......@@ -1065,3 +1065,50 @@ NextRandom:
return true
}
// bytes writes the value of z into buf using big-endian encoding.
// len(buf) must be >= len(z)*_S. The value of z is encoded in the
// slice buf[i:]. The number i of unused bytes at the beginning of
// buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
i = len(buf)
for _, d := range z {
for j := 0; j < _S; j++ {
i--
buf[i] = byte(d)
d >>= 8
}
}
for i < len(buf) && buf[i] == 0 {
i++
}
return
}
// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z nat) setBytes(buf []byte) nat {
z = z.make((len(buf) + _S - 1) / _S)
k := 0
s := uint(0)
var d Word
for i := len(buf); i > 0; i-- {
d |= Word(buf[i-1]) << s
if s += 8; s == _S*8 {
z[k] = d
k++
s = 0
d = 0
}
}
if k < len(z) {
z[k] = d
}
return z.norm()
}
......@@ -2,9 +2,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bufio
package bufio_test
import (
. "bufio"
"bytes"
"fmt"
"io"
......@@ -502,9 +503,8 @@ func TestWriteString(t *testing.T) {
b.WriteString("7890") // easy after flush
b.WriteString("abcdefghijklmnopqrstuvwxy") // hard
b.WriteString("z")
b.Flush()
if b.err != nil {
t.Error("WriteString", b.err)
if err := b.Flush(); err != nil {
t.Error("WriteString", err)
}
s := "01234567890abcdefghijklmnopqrstuvwxyz"
if string(buf.Bytes()) != s {
......
......@@ -191,9 +191,16 @@ func testSync(t *testing.T, level int, input []byte, name string) {
t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo])
return
}
if i == 0 && buf.buf.Len() != 0 {
t.Errorf("testSync/%d (%d, %d, %s): extra data after %d", i, level, len(input), name, hi-lo)
}
// This test originally checked that after reading
// the first half of the input, there was nothing left
// in the read buffer (buf.buf.Len() != 0) but that is
// not necessarily the case: the write Flush may emit
// some extra framing bits that are not necessary
// to process to obtain the first half of the uncompressed
// data. The test ran correctly most of the time, because
// the background goroutine had usually read even
// those extra bits by now, but it's not a useful thing to
// check.
buf.WriteMode()
}
buf.ReadMode()
......
......@@ -9,6 +9,7 @@ import (
"io"
"io/ioutil"
"os"
"runtime"
"strconv"
"strings"
"testing"
......@@ -117,16 +118,34 @@ func (devNull) Write(p []byte) (int, os.Error) {
return len(p), nil
}
func BenchmarkDecoder(b *testing.B) {
func benchmarkDecoder(b *testing.B, n int) {
b.StopTimer()
b.SetBytes(int64(n))
buf0, _ := ioutil.ReadFile("../testdata/e.txt")
buf0 = buf0[:10000]
compressed := bytes.NewBuffer(nil)
w := NewWriter(compressed, LSB, 8)
io.Copy(w, bytes.NewBuffer(buf0))
for i := 0; i < n; i += len(buf0) {
io.Copy(w, bytes.NewBuffer(buf0))
}
w.Close()
buf1 := compressed.Bytes()
buf0, compressed, w = nil, nil, nil
runtime.GC()
b.StartTimer()
for i := 0; i < b.N; i++ {
io.Copy(devNull{}, NewReader(bytes.NewBuffer(buf1), LSB, 8))
}
}
func BenchmarkDecoder1e4(b *testing.B) {
benchmarkDecoder(b, 1e4)
}
func BenchmarkDecoder1e5(b *testing.B) {
benchmarkDecoder(b, 1e5)
}
func BenchmarkDecoder1e6(b *testing.B) {
benchmarkDecoder(b, 1e6)
}
......@@ -8,6 +8,7 @@ import (
"io"
"io/ioutil"
"os"
"runtime"
"testing"
)
......@@ -99,13 +100,33 @@ func TestWriter(t *testing.T) {
}
}
func BenchmarkEncoder(b *testing.B) {
func benchmarkEncoder(b *testing.B, n int) {
b.StopTimer()
buf, _ := ioutil.ReadFile("../testdata/e.txt")
b.SetBytes(int64(n))
buf0, _ := ioutil.ReadFile("../testdata/e.txt")
buf0 = buf0[:10000]
buf1 := make([]byte, n)
for i := 0; i < n; i += len(buf0) {
copy(buf1[i:], buf0)
}
buf0 = nil
runtime.GC()
b.StartTimer()
for i := 0; i < b.N; i++ {
w := NewWriter(devNull{}, LSB, 8)
w.Write(buf)
w.Write(buf1)
w.Close()
}
}
func BenchmarkEncoder1e4(b *testing.B) {
benchmarkEncoder(b, 1e4)
}
func BenchmarkEncoder1e5(b *testing.B) {
benchmarkEncoder(b, 1e5)
}
func BenchmarkEncoder1e6(b *testing.B) {
benchmarkEncoder(b, 1e6)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ecdsa implements the Elliptic Curve Digital Signature Algorithm, as
// defined in FIPS 186-3.
package ecdsa
// References:
// [NSA]: Suite B implementor's guide to FIPS 186-3,
// http://www.nsa.gov/ia/_files/ecdsa.pdf
// [SECG]: SECG, SEC1
// http://www.secg.org/download/aid-780/sec1-v2.pdf
import (
"big"
"crypto/elliptic"
"io"
"os"
)
// PublicKey represents an ECDSA public key.
type PublicKey struct {
*elliptic.Curve
X, Y *big.Int
}
// PrivateKey represents an ECDSA private key.
type PrivateKey struct {
PublicKey
D *big.Int
}
var one = new(big.Int).SetInt64(1)
// randFieldElement returns a random element of the field underlying the given
// curve using the procedure given in [NSA] A.2.1.
func randFieldElement(c *elliptic.Curve, rand io.Reader) (k *big.Int, err os.Error) {
b := make([]byte, c.BitSize/8+8)
_, err = rand.Read(b)
if err != nil {
return
}
k = new(big.Int).SetBytes(b)
n := new(big.Int).Sub(c.N, one)
k.Mod(k, n)
k.Add(k, one)
return
}
// GenerateKey generates a public&private key pair.
func GenerateKey(c *elliptic.Curve, rand io.Reader) (priv *PrivateKey, err os.Error) {
k, err := randFieldElement(c, rand)
if err != nil {
return
}
priv = new(PrivateKey)
priv.PublicKey.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
return
}
// hashToInt converts a hash value to an integer. There is some disagreement
// about how this is done. [NSA] suggests that this is done in the obvious
// manner, but [SECG] truncates the hash to the bit-length of the curve order
// first. We follow [SECG] because that's what OpenSSL does.
func hashToInt(hash []byte, c *elliptic.Curve) *big.Int {
orderBits := c.N.BitLen()
orderBytes := (orderBits + 7) / 8
if len(hash) > orderBytes {
hash = hash[:orderBytes]
}
ret := new(big.Int).SetBytes(hash)
excess := orderBytes*8 - orderBits
if excess > 0 {
ret.Rsh(ret, uint(excess))
}
return ret
}
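A worked example of this truncation (illustrative): with the P-224 curve, orderBits = 224 and orderBytes = 28, so a 32-byte SHA-256 digest is first cut to its leftmost 28 bytes, and excess = 28*8 - 224 = 0 leaves no final shift. With P-521, orderBits = 521 and orderBytes = 66, so a 64-byte SHA-512 digest is not cut, but excess = 66*8 - 521 = 7 still shifts the result right by 7 bits, since the code keys off the order's byte length rather than the digest's.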
// Sign signs an arbitrary length hash (which should be the result of hashing a
// larger message) using the private key, priv. It returns the signature as a
// pair of integers. The security of the private key depends on the entropy of
// rand.
func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err os.Error) {
// See [NSA] 3.4.1
c := priv.PublicKey.Curve
var k, kInv *big.Int
for {
for {
k, err = randFieldElement(c, rand)
if err != nil {
r = nil
return
}
kInv = new(big.Int).ModInverse(k, c.N)
r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
r.Mod(r, priv.Curve.N)
if r.Sign() != 0 {
break
}
}
e := hashToInt(hash, c)
s = new(big.Int).Mul(priv.D, r)
s.Add(s, e)
s.Mul(s, kInv)
s.Mod(s, priv.PublicKey.Curve.N)
if s.Sign() != 0 {
break
}
}
return
}
// Verify verifies the signature in r, s of hash using the public key, pub. It
// returns true iff the signature is valid.
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
// See [NSA] 3.4.2
c := pub.Curve
if r.Sign() == 0 || s.Sign() == 0 {
return false
}
if r.Cmp(c.N) >= 0 || s.Cmp(c.N) >= 0 {
return false
}
e := hashToInt(hash, c)
w := new(big.Int).ModInverse(s, c.N)
u1 := e.Mul(e, w)
u2 := w.Mul(r, w)
x1, y1 := c.ScalarBaseMult(u1.Bytes())
x2, y2 := c.ScalarMult(pub.X, pub.Y, u2.Bytes())
if x1.Cmp(x2) == 0 {
return false
}
x, _ := c.Add(x1, y1, x2, y2)
x.Mod(x, c.N)
return x.Cmp(r) == 0
}
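A minimal usage sketch of this package as it stands in the patch, assuming the pre-Go 1 import paths, an elliptic.P256() accessor for the curve initialized below, and the no-argument hash.Hash.Sum of this tree (as VerifySignature uses elsewhere in this change):
package main
import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
)
func main() {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		fmt.Println(err)
		return
	}
	h := sha256.New()
	h.Write([]byte("hello, world"))
	digest := h.Sum()
	r, s, err := ecdsa.Sign(rand.Reader, priv, digest)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("valid:", ecdsa.Verify(&priv.PublicKey, digest, r, s))
}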
......@@ -24,6 +24,7 @@ import (
// See http://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
type Curve struct {
P *big.Int // the order of the underlying field
N *big.Int // the order of the base point
B *big.Int // the constant of the curve equation
Gx, Gy *big.Int // (x,y) of the base point
BitSize int // the size of the underlying field
......@@ -315,6 +316,7 @@ func initP224() {
// See FIPS 186-3, section D.2.2
p224 = new(Curve)
p224.P, _ = new(big.Int).SetString("26959946667150639794667015087019630673557916260026308143510066298881", 10)
p224.N, _ = new(big.Int).SetString("26959946667150639794667015087019625940457807714424391721682722368061", 10)
p224.B, _ = new(big.Int).SetString("b4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4", 16)
p224.Gx, _ = new(big.Int).SetString("b70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21", 16)
p224.Gy, _ = new(big.Int).SetString("bd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34", 16)
......@@ -325,6 +327,7 @@ func initP256() {
// See FIPS 186-3, section D.2.3
p256 = new(Curve)
p256.P, _ = new(big.Int).SetString("115792089210356248762697446949407573530086143415290314195533631308867097853951", 10)
p256.N, _ = new(big.Int).SetString("115792089210356248762697446949407573529996955224135760342422259061068512044369", 10)
p256.B, _ = new(big.Int).SetString("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", 16)
p256.Gx, _ = new(big.Int).SetString("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", 16)
p256.Gy, _ = new(big.Int).SetString("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", 16)
......@@ -335,6 +338,7 @@ func initP384() {
// See FIPS 186-3, section D.2.4
p384 = new(Curve)
p384.P, _ = new(big.Int).SetString("39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319", 10)
p384.N, _ = new(big.Int).SetString("39402006196394479212279040100143613805079739270465446667946905279627659399113263569398956308152294913554433653942643", 10)
p384.B, _ = new(big.Int).SetString("b3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef", 16)
p384.Gx, _ = new(big.Int).SetString("aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab7", 16)
p384.Gy, _ = new(big.Int).SetString("3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5f", 16)
......@@ -345,6 +349,7 @@ func initP521() {
// See FIPS 186-3, section D.2.5
p521 = new(Curve)
p521.P, _ = new(big.Int).SetString("6864797660130609714981900799081393217269435300143305409394463459185543183397656052122559640661454554977296311391480858037121987999716643812574028291115057151", 10)
p521.N, _ = new(big.Int).SetString("6864797660130609714981900799081393217269435300143305409394463459185543183397655394245057746333217197532963996371363321113864768612440380340372808892707005449", 10)
p521.B, _ = new(big.Int).SetString("051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00", 16)
p521.Gx, _ = new(big.Int).SetString("c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66", 16)
p521.Gy, _ = new(big.Int).SetString("11839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16650", 16)
......
......@@ -7,6 +7,7 @@
package packet
import (
"big"
"crypto/aes"
"crypto/cast5"
"crypto/cipher"
......@@ -166,10 +167,10 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader,
return
}
// serialiseHeader writes an OpenPGP packet header to w. See RFC 4880, section
// serializeHeader writes an OpenPGP packet header to w. See RFC 4880, section
// 4.2.
func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
var buf [5]byte
func serializeHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
var buf [6]byte
var n int
buf[0] = 0x80 | 0x40 | byte(ptype)
......@@ -178,16 +179,16 @@ func serialiseHeader(w io.Writer, ptype packetType, length int) (err os.Error) {
n = 2
} else if length < 8384 {
length -= 192
buf[1] = byte(length >> 8)
buf[1] = 192 + byte(length>>8)
buf[2] = byte(length)
n = 3
} else {
buf[0] = 255
buf[1] = byte(length >> 24)
buf[2] = byte(length >> 16)
buf[3] = byte(length >> 8)
buf[4] = byte(length)
n = 5
buf[1] = 255
buf[2] = byte(length >> 24)
buf[3] = byte(length >> 16)
buf[4] = byte(length >> 8)
buf[5] = byte(length)
n = 6
}
_, err = w.Write(buf[:n])
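Worked examples of the corrected new-format lengths (illustrative; buf[0] always carries the tag byte): length 100 uses the one-octet form, buf[1] = 100 and n = 2; length 1000 uses the two-octet form, 1000 - 192 = 808, so buf[1] = 192 + 3 = 195 and buf[2] = 40, and a reader recovers (195-192)*256 + 40 + 192 = 1000; length 100000 uses the five-octet form, buf[1] = 255 followed by the value as a big-endian 32-bit integer, with n = 6.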
......@@ -371,7 +372,7 @@ func (cipher CipherFunction) new(key []byte) (block cipher.Block) {
// readMPI reads a big integer from r. The bit length returned is the bit
// length that was specified in r. This is preserved so that the integer can be
// reserialised exactly.
// reserialized exactly.
func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
var buf [2]byte
_, err = readFull(r, buf[0:])
......@@ -385,7 +386,7 @@ func readMPI(r io.Reader) (mpi []byte, bitLength uint16, err os.Error) {
return
}
// writeMPI serialises a big integer to r.
// writeMPI serializes a big integer to w.
func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
_, err = w.Write([]byte{byte(bitLength >> 8), byte(bitLength)})
if err == nil {
......@@ -393,3 +394,8 @@ func writeMPI(w io.Writer, bitLength uint16, mpiBytes []byte) (err os.Error) {
}
return
}
// writeBig serializes a *big.Int to w.
func writeBig(w io.Writer, i *big.Int) os.Error {
return writeMPI(w, uint16(i.BitLen()), i.Bytes())
}
......@@ -190,3 +190,23 @@ func TestReadHeader(t *testing.T) {
}
}
}
func TestSerializeHeader(t *testing.T) {
tag := packetTypePublicKey
lengths := []int{0, 1, 2, 64, 192, 193, 8000, 8384, 8385, 10000}
for _, length := range lengths {
buf := bytes.NewBuffer(nil)
serializeHeader(buf, tag, length)
tag2, length2, _, err := readHeader(buf)
if err != nil {
t.Errorf("length %d, err: %s", length, err)
}
if tag2 != tag {
t.Errorf("length %d, tag incorrect (got %d, want %d)", length, tag2, tag)
}
if int(length2) != length {
t.Errorf("length %d, length incorrect (got %d)", length, length2)
}
}
}
......@@ -8,6 +8,7 @@ import (
"big"
"bytes"
"crypto/cipher"
"crypto/dsa"
"crypto/openpgp/error"
"crypto/openpgp/s2k"
"crypto/rsa"
......@@ -134,7 +135,16 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) os.Error {
}
func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
// TODO(agl): support DSA and ECDSA private keys.
switch pk.PublicKey.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoRSAEncryptOnly:
return pk.parseRSAPrivateKey(data)
case PubKeyAlgoDSA:
return pk.parseDSAPrivateKey(data)
}
panic("impossible")
}
func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err os.Error) {
rsaPub := pk.PublicKey.PublicKey.(*rsa.PublicKey)
rsaPriv := new(rsa.PrivateKey)
rsaPriv.PublicKey = *rsaPub
......@@ -162,3 +172,22 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err os.Error) {
return nil
}
func (pk *PrivateKey) parseDSAPrivateKey(data []byte) (err os.Error) {
dsaPub := pk.PublicKey.PublicKey.(*dsa.PublicKey)
dsaPriv := new(dsa.PrivateKey)
dsaPriv.PublicKey = *dsaPub
buf := bytes.NewBuffer(data)
x, _, err := readMPI(buf)
if err != nil {
return
}
dsaPriv.X = new(big.Int).SetBytes(x)
pk.PrivateKey = dsaPriv
pk.Encrypted = false
pk.encryptedData = nil
return nil
}
......@@ -11,6 +11,7 @@ import (
"crypto/rsa"
"crypto/sha1"
"encoding/binary"
"fmt"
"hash"
"io"
"os"
......@@ -178,12 +179,6 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
return error.InvalidArgumentError("public key cannot generate signatures")
}
rsaPublicKey, ok := pk.PublicKey.(*rsa.PublicKey)
if !ok {
// TODO(agl): support DSA and ECDSA keys.
return error.UnsupportedError("non-RSA public key")
}
signed.Write(sig.HashSuffix)
hashBytes := signed.Sum()
......@@ -191,11 +186,28 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err os.E
return error.SignatureError("hash tag doesn't match")
}
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.Signature)
if err != nil {
return error.SignatureError("RSA verification failure")
if pk.PubKeyAlgo != sig.PubKeyAlgo {
return error.InvalidArgumentError("public key and signature use different algorithms")
}
return nil
switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature)
if err != nil {
return error.SignatureError("RSA verification failure")
}
return nil
case PubKeyAlgoDSA:
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
if !dsa.Verify(dsaPublicKey, hashBytes, sig.DSASigR, sig.DSASigS) {
return error.SignatureError("DSA verification failure")
}
return nil
default:
panic("shouldn't happen")
}
panic("unreachable")
}
// VerifyKeySignature returns nil iff sig is a valid signature, made by this
......@@ -239,9 +251,21 @@ func (pk *PublicKey) VerifyUserIdSignature(id string, sig *Signature) (err os.Er
return pk.VerifySignature(h, sig)
}
// KeyIdString returns the public key's fingerprint in capital hex
// (e.g. "6C7EE1B8621CC013").
func (pk *PublicKey) KeyIdString() string {
return fmt.Sprintf("%X", pk.Fingerprint[12:20])
}
// KeyIdShortString returns the short form of public key's fingerprint
// in capital hex, as shown by gpg --list-keys (e.g. "621CC013").
func (pk *PublicKey) KeyIdShortString() string {
return fmt.Sprintf("%X", pk.Fingerprint[16:20])
}
// A parsedMPI is used to store the contents of a big integer, along with the
// bit length that was specified in the original input. This allows the MPI to
// be reserialised exactly.
// be reserialized exactly.
type parsedMPI struct {
bytes []byte
bitLength uint16
......
......@@ -16,9 +16,11 @@ var pubKeyTests = []struct {
creationTime uint32
pubKeyAlgo PublicKeyAlgorithm
keyId uint64
keyIdString string
keyIdShort string
}{
{rsaPkDataHex, rsaFingerprintHex, 0x4d3c5c10, PubKeyAlgoRSA, 0xa34d7e18c20c31bb},
{dsaPkDataHex, dsaFingerprintHex, 0x4d432f89, PubKeyAlgoDSA, 0x8e8fbe54062f19ed},
{rsaPkDataHex, rsaFingerprintHex, 0x4d3c5c10, PubKeyAlgoRSA, 0xa34d7e18c20c31bb, "A34D7E18C20C31BB", "C20C31BB"},
{dsaPkDataHex, dsaFingerprintHex, 0x4d432f89, PubKeyAlgoDSA, 0x8e8fbe54062f19ed, "8E8FBE54062F19ED", "062F19ED"},
}
func TestPublicKeyRead(t *testing.T) {
......@@ -46,6 +48,12 @@ func TestPublicKeyRead(t *testing.T) {
if pk.KeyId != test.keyId {
t.Errorf("#%d: bad keyid got:%x want:%x", i, pk.KeyId, test.keyId)
}
if g, e := pk.KeyIdString(), test.keyIdString; g != e {
t.Errorf("#%d: bad KeyIdString got:%q want:%q", i, g, e)
}
if g, e := pk.KeyIdShortString(), test.keyIdShort; g != e {
t.Errorf("#%d: bad KeyIdShortString got:%q want:%q", i, g, e)
}
}
}
......
......@@ -5,7 +5,9 @@
package packet
import (
"big"
"crypto"
"crypto/dsa"
"crypto/openpgp/error"
"crypto/openpgp/s2k"
"crypto/rand"
......@@ -29,7 +31,9 @@ type Signature struct {
// of bad signed data.
HashTag [2]byte
CreationTime uint32 // Unix epoch time
Signature []byte
RSASignature []byte
DSASigR, DSASigS *big.Int
// The following are optional so are nil when not included in the
// signature.
......@@ -66,7 +70,7 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
sig.SigType = SignatureType(buf[0])
sig.PubKeyAlgo = PublicKeyAlgorithm(buf[1])
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA:
default:
err = error.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo)))
return
......@@ -122,8 +126,20 @@ func (sig *Signature) parse(r io.Reader) (err os.Error) {
return
}
// We have already checked that the public key algorithm is RSA.
sig.Signature, _, err = readMPI(r)
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sig.RSASignature, _, err = readMPI(r)
case PubKeyAlgoDSA:
var rBytes, sBytes []byte
rBytes, _, err = readMPI(r)
sig.DSASigR = new(big.Int).SetBytes(rBytes)
if err == nil {
sBytes, _, err = readMPI(r)
sig.DSASigS = new(big.Int).SetBytes(sBytes)
}
default:
panic("unreachable")
}
return
}
......@@ -316,8 +332,8 @@ func subpacketLengthLength(length int) int {
return 5
}
// serialiseSubpacketLength marshals the given length into to.
func serialiseSubpacketLength(to []byte, length int) int {
// serializeSubpacketLength marshals the given length into to.
func serializeSubpacketLength(to []byte, length int) int {
if length < 192 {
to[0] = byte(length)
return 1
......@@ -336,7 +352,7 @@ func serialiseSubpacketLength(to []byte, length int) int {
return 5
}
// subpacketsLength returns the serialised length, in bytes, of the given
// subpacketsLength returns the serialized length, in bytes, of the given
// subpackets.
func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) {
for _, subpacket := range subpackets {
......@@ -349,11 +365,11 @@ func subpacketsLength(subpackets []outputSubpacket, hashed bool) (length int) {
return
}
// serialiseSubpackets marshals the given subpackets into to.
func serialiseSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
// serializeSubpackets marshals the given subpackets into to.
func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) {
for _, subpacket := range subpackets {
if subpacket.hashed == hashed {
n := serialiseSubpacketLength(to, len(subpacket.contents)+1)
n := serializeSubpacketLength(to, len(subpacket.contents)+1)
to[n] = byte(subpacket.subpacketType)
to = to[1+n:]
n = copy(to, subpacket.contents)
......@@ -381,7 +397,7 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
}
sig.HashSuffix[4] = byte(hashedSubpacketsLen >> 8)
sig.HashSuffix[5] = byte(hashedSubpacketsLen)
serialiseSubpackets(sig.HashSuffix[6:l], sig.outSubpackets, true)
serializeSubpackets(sig.HashSuffix[6:l], sig.outSubpackets, true)
trailer := sig.HashSuffix[l:]
trailer[0] = 4
trailer[1] = 0xff
......@@ -392,32 +408,66 @@ func (sig *Signature) buildHashSuffix() (err os.Error) {
return
}
// SignRSA signs a message with an RSA private key. The hash, h, must contain
// the hash of message to be signed and will be mutated by this function.
func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err os.Error) {
err = sig.buildHashSuffix()
if err != nil {
return
}
h.Write(sig.HashSuffix)
digest := h.Sum()
digest = h.Sum()
copy(sig.HashTag[:], digest)
sig.Signature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
return
}
// Serialize marshals sig to w. SignRSA must have been called first.
// SignRSA signs a message with an RSA private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignRSA(h hash.Hash, priv *rsa.PrivateKey) (err os.Error) {
digest, err := sig.signPrepareHash(h)
if err != nil {
return
}
sig.RSASignature, err = rsa.SignPKCS1v15(rand.Reader, priv, sig.Hash, digest)
return
}
// SignDSA signs a message with a DSA private key. The hash, h, must contain
// the hash of the message to be signed and will be mutated by this function.
// On success, the signature is stored in sig. Call Serialize to write it out.
func (sig *Signature) SignDSA(h hash.Hash, priv *dsa.PrivateKey) (err os.Error) {
digest, err := sig.signPrepareHash(h)
if err != nil {
return
}
sig.DSASigR, sig.DSASigS, err = dsa.Sign(rand.Reader, priv, digest)
return
}
// Serialize marshals sig to w. SignRSA or SignDSA must have been called first.
func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
if sig.Signature == nil {
return error.InvalidArgumentError("Signature: need to call SignRSA before Serialize")
if sig.RSASignature == nil && sig.DSASigR == nil {
return error.InvalidArgumentError("Signature: need to call SignRSA or SignDSA before Serialize")
}
sigLength := 0
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
sigLength = len(sig.RSASignature)
case PubKeyAlgoDSA:
sigLength = 2 /* MPI length */
sigLength += (sig.DSASigR.BitLen() + 7) / 8
sigLength += 2 /* MPI length */
sigLength += (sig.DSASigS.BitLen() + 7) / 8
default:
panic("impossible")
}
unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false)
length := len(sig.HashSuffix) - 6 /* trailer not included */ +
2 /* length of unhashed subpackets */ + unhashedSubpacketsLen +
2 /* hash tag */ + 2 /* length of signature MPI */ + len(sig.Signature)
err = serialiseHeader(w, packetTypeSignature, length)
2 /* hash tag */ + 2 /* length of signature MPI */ + sigLength
err = serializeHeader(w, packetTypeSignature, length)
if err != nil {
return
}
......@@ -430,7 +480,7 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
unhashedSubpackets := make([]byte, 2+unhashedSubpacketsLen)
unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 8)
unhashedSubpackets[1] = byte(unhashedSubpacketsLen)
serialiseSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false)
serializeSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false)
_, err = w.Write(unhashedSubpackets)
if err != nil {
......@@ -440,7 +490,19 @@ func (sig *Signature) Serialize(w io.Writer) (err os.Error) {
if err != nil {
return
}
return writeMPI(w, 8*uint16(len(sig.Signature)), sig.Signature)
switch sig.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly:
err = writeMPI(w, 8*uint16(len(sig.RSASignature)), sig.RSASignature)
case PubKeyAlgoDSA:
err = writeBig(w, sig.DSASigR)
if err == nil {
err = writeBig(w, sig.DSASigS)
}
default:
panic("impossible")
}
return
}
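A worked example of the DSA arm of the length computation (illustrative numbers): for 160-bit DSASigR and DSASigS values, each MPI is a 2-byte bit-length prefix plus (160+7)/8 = 20 bytes of value, so sigLength = 2 + 20 + 2 + 20 = 44.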
// outputSubpacket represents a subpacket to be marshaled.
......
......@@ -6,6 +6,7 @@ package openpgp
import (
"crypto"
"crypto/dsa"
"crypto/openpgp/armor"
"crypto/openpgp/error"
"crypto/openpgp/packet"
......@@ -39,7 +40,7 @@ func DetachSignText(w io.Writer, signer *Entity, message io.Reader) os.Error {
// ArmoredDetachSignText signs message (after canonicalising the line endings)
// with the private key from signer (which must already have been decrypted)
// and writes an armored signature to w.
func SignTextDetachedArmored(w io.Writer, signer *Entity, message io.Reader) os.Error {
func ArmoredDetachSignText(w io.Writer, signer *Entity, message io.Reader) os.Error {
return armoredDetachSign(w, signer, message, packet.SigTypeText)
}
......@@ -80,6 +81,9 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSASignOnly:
priv := signer.PrivateKey.PrivateKey.(*rsa.PrivateKey)
err = sig.SignRSA(h, priv)
case packet.PubKeyAlgoDSA:
priv := signer.PrivateKey.PrivateKey.(*dsa.PrivateKey)
err = sig.SignDSA(h, priv)
default:
err = error.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo)))
}
......
......@@ -18,7 +18,7 @@ func TestSignDetached(t *testing.T) {
t.Error(err)
}
testDetachedSignature(t, kring, out, signedInput, "check")
testDetachedSignature(t, kring, out, signedInput, "check", testKey1KeyId)
}
func TestSignTextDetached(t *testing.T) {
......@@ -30,5 +30,17 @@ func TestSignTextDetached(t *testing.T) {
t.Error(err)
}
testDetachedSignature(t, kring, out, signedInput, "check")
testDetachedSignature(t, kring, out, signedInput, "check", testKey1KeyId)
}
func TestSignDetachedDSA(t *testing.T) {
kring, _ := ReadKeyRing(readerFromHex(dsaTestKeyPrivateHex))
out := bytes.NewBuffer(nil)
message := bytes.NewBufferString(signedInput)
err := DetachSign(out, kring[0], message)
if err != nil {
t.Error(err)
}
testDetachedSignature(t, kring, out, signedInput, "check", testKey3KeyId)
}
......@@ -7,6 +7,7 @@ package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"io"
"io/ioutil"
"sync"
......@@ -95,6 +96,9 @@ type ConnectionState struct {
HandshakeComplete bool
CipherSuite uint16
NegotiatedProtocol string
// the certificate chain that was presented by the other side
PeerCertificates []*x509.Certificate
}
// A Config structure is used to configure a TLS client or server. After one
......
......@@ -762,6 +762,7 @@ func (c *Conn) ConnectionState() ConnectionState {
if c.handshakeComplete {
state.NegotiatedProtocol = c.clientProtocol
state.CipherSuite = c.cipherSuite
state.PeerCertificates = c.peerCertificates
}
return state
......@@ -776,15 +777,6 @@ func (c *Conn) OCSPResponse() []byte {
return c.ocspResponse
}
// PeerCertificates returns the certificate chain that was presented by the
// other side.
func (c *Conn) PeerCertificates() []*x509.Certificate {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.peerCertificates
}
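Callers migrate to the field added to ConnectionState; a minimal sketch, assuming conn is an established *tls.Conn and that x509.Certificate exposes Subject.CommonName as in this tree (error handling elided):
state := conn.ConnectionState()
if state.HandshakeComplete {
	for _, cert := range state.PeerCertificates {
		fmt.Println(cert.Subject.CommonName)
	}
}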
// VerifyHostname checks that the peer certificate chain is valid for
// connecting to host. If so, it returns nil; if not, it returns an os.Error
// describing the problem.
......
......@@ -25,7 +25,7 @@ func main() {
priv, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
log.Exitf("failed to generate private key: %s", err)
log.Fatalf("failed to generate private key: %s", err)
return
}
......@@ -46,13 +46,13 @@ func main() {
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
log.Exitf("Failed to create certificate: %s", err)
log.Fatalf("Failed to create certificate: %s", err)
return
}
certOut, err := os.Open("cert.pem", os.O_WRONLY|os.O_CREAT, 0644)
if err != nil {
log.Exitf("failed to open cert.pem for writing: %s", err)
log.Fatalf("failed to open cert.pem for writing: %s", err)
return
}
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
......
......@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
return nil, os.NewError("debug/proc not implemented on OS X")
}
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
return Attach(0)
}
......@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
return nil, os.NewError("debug/proc not implemented on FreeBSD")
}
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
return Attach(0)
}
......@@ -1279,25 +1279,31 @@ func Attach(pid int) (Process, os.Error) {
return p, nil
}
// ForkExec forks the current process and execs argv0, stopping the
// new process after the exec syscall. See os.ForkExec for additional
// StartProcess forks the current process and execs argv0, stopping the
// new process after the exec syscall. See os.StartProcess for additional
// details.
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
sysattr := &syscall.ProcAttr{
Dir: attr.Dir,
Env: attr.Env,
Ptrace: true,
}
p := newProcess(-1)
// Create array of integer (system) fds.
intfd := make([]int, len(fd))
for i, f := range fd {
intfd := make([]int, len(attr.Files))
for i, f := range attr.Files {
if f == nil {
intfd[i] = -1
} else {
intfd[i] = f.Fd()
}
}
sysattr.Files = intfd
// Fork from the monitor thread so we get the right tracer pid.
err := p.do(func() os.Error {
pid, errno := syscall.PtraceForkExec(argv0, argv, envv, dir, intfd)
pid, _, errno := syscall.StartProcess(argv0, argv, sysattr)
if errno != 0 {
return &os.PathError{"fork/exec", argv0, os.Errno(errno)}
}
......
......@@ -12,6 +12,6 @@ func Attach(pid int) (Process, os.Error) {
return nil, os.NewError("debug/proc not implemented on windows")
}
func ForkExec(argv0 string, argv []string, envv []string, dir string, fd []*os.File) (Process, os.Error) {
func StartProcess(argv0 string, argv []string, attr *os.ProcAttr) (Process, os.Error) {
return Attach(0)
}
......@@ -75,17 +75,19 @@ func modeToFiles(mode, fd int) (*os.File, *os.File, os.Error) {
// Run starts the named binary running with
// arguments argv and environment envv.
// If the dir argument is not empty, the child changes
// into the directory before executing the binary.
// It returns a pointer to a new Cmd representing
// the command or an error.
//
// The parameters stdin, stdout, and stderr
// The arguments stdin, stdout, and stderr
// specify how to handle standard input, output, and error.
// The choices are DevNull (connect to /dev/null),
// PassThrough (connect to the current process's standard stream),
// Pipe (connect to an operating system pipe), and
// MergeWithStdout (only for standard error; use the same
// file descriptor as was used for standard output).
// If a parameter is Pipe, then the corresponding field (Stdin, Stdout, Stderr)
// If an argument is Pipe, then the corresponding field (Stdin, Stdout, Stderr)
// of the returned Cmd is the other end of the pipe.
// Otherwise the field in Cmd is nil.
func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int) (c *Cmd, err os.Error) {
......@@ -105,7 +107,7 @@ func Run(name string, argv, envv []string, dir string, stdin, stdout, stderr int
}
// Run command.
c.Process, err = os.StartProcess(name, argv, envv, dir, fd[0:])
c.Process, err = os.StartProcess(name, argv, &os.ProcAttr{Dir: dir, Files: fd[:], Env: envv})
if err != nil {
goto Error
}
......
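The shape of the new call, as a standalone sketch (the program path and the Wait options flag are assumptions for illustration, not taken from this diff):
attr := &os.ProcAttr{
	Dir:   "", // empty means inherit the current directory
	Env:   os.Environ(),
	Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}, // fds 0, 1, 2 in the child
}
p, err := os.StartProcess("/bin/date", []string{"date"}, attr)
if err == nil {
	p.Wait(0) // pre-Go 1 Wait takes an options argument (assumed here)
}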
......@@ -118,3 +118,55 @@ func TestAddEnvVar(t *testing.T) {
t.Fatal("close:", err)
}
}
var tryargs = []string{
`2`,
`2 `,
"2 \t",
`2" "`,
`2 ab `,
`2 "ab" `,
`2 \ `,
`2 \\ `,
`2 \" `,
`2 \`,
`2\`,
`2"`,
`2\"`,
`2 "`,
`2 \"`,
``,
`2 ^ `,
`2 \^`,
}
func TestArgs(t *testing.T) {
for _, a := range tryargs {
argv := []string{
"awk",
`BEGIN{printf("%s|%s|%s",ARGV[1],ARGV[2],ARGV[3])}`,
"/dev/null",
a,
"EOF",
}
exe, err := LookPath(argv[0])
if err != nil {
t.Fatal("run:", err)
}
cmd, err := Run(exe, argv, nil, "", DevNull, Pipe, DevNull)
if err != nil {
t.Fatal("run:", err)
}
buf, err := ioutil.ReadAll(cmd.Stdout)
if err != nil {
t.Fatal("read:", err)
}
expect := "/dev/null|" + a + "|EOF"
if string(buf) != expect {
t.Errorf("read: got %q expect %q", buf, expect)
}
if err = cmd.Close(); err != nil {
t.Fatal("close:", err)
}
}
}
......@@ -287,9 +287,6 @@ func (a *stmtCompiler) compile(s ast.Stmt) {
case *ast.SwitchStmt:
a.compileSwitchStmt(s)
case *ast.TypeCaseClause:
notimpl = true
case *ast.TypeSwitchStmt:
notimpl = true
......@@ -1012,13 +1009,13 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
a.diagAt(clause.Pos(), "switch statement must contain case clauses")
continue
}
if clause.Values == nil {
if clause.List == nil {
if hasDefault {
a.diagAt(clause.Pos(), "switch statement contains more than one default case")
}
hasDefault = true
} else {
ncases += len(clause.Values)
ncases += len(clause.List)
}
}
......@@ -1030,7 +1027,7 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
if !ok {
continue
}
for _, v := range clause.Values {
for _, v := range clause.List {
e := condbc.compileExpr(condbc.block, false, v)
switch {
case e == nil:
......@@ -1077,8 +1074,8 @@ func (a *stmtCompiler) compileSwitchStmt(s *ast.SwitchStmt) {
// Save jump PC's
pc := a.nextPC()
if clause.Values != nil {
for _ = range clause.Values {
if clause.List != nil {
for _ = range clause.List {
casePCs[i] = &pc
i++
}
......
......@@ -27,7 +27,7 @@ var stmtTests = []test{
CErr("i, u := 1, 2", atLeastOneDecl),
Val2("i, x := 2, f", "i", 2, "x", 1.0),
// Various errors
CErr("1 := 2", "left side of := must be a name"),
CErr("1 := 2", "expected identifier"),
CErr("c, a := 1, 1", "cannot assign"),
// Unpacking
Val2("x, y := oneTwo()", "x", 1, "y", 2),
......
......@@ -160,7 +160,7 @@ func cmdLoad(args []byte) os.Error {
} else {
fname = parts[0]
}
tproc, err = proc.ForkExec(fname, parts, os.Environ(), "", []*os.File{os.Stdin, os.Stdout, os.Stderr})
tproc, err = proc.StartProcess(fname, parts, &os.ProcAttr{Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}})
if err != nil {
return err
}
......
......@@ -269,7 +269,7 @@ func Iter() <-chan KeyValue {
}
func expvarHandler(w http.ResponseWriter, r *http.Request) {
w.SetHeader("content-type", "application/json; charset=utf-8")
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprintf(w, "{\n")
first := true
for name, value := range vars {
......
......@@ -56,7 +56,7 @@
flag.Bool(...) // global options
flag.Parse() // parse leading command
subcmd := flag.Args(0)
subcmd := flag.Arg(0)
switch subcmd {
// add per-subcommand options
}
......@@ -68,6 +68,7 @@ package flag
import (
"fmt"
"os"
"sort"
"strconv"
)
......@@ -205,16 +206,34 @@ type allFlags struct {
var flags *allFlags
// VisitAll visits the flags, calling fn for each. It visits all flags, even those not set.
// sortFlags returns the flags as a slice in lexicographical sorted order.
func sortFlags(flags map[string]*Flag) []*Flag {
list := make(sort.StringArray, len(flags))
i := 0
for _, f := range flags {
list[i] = f.Name
i++
}
list.Sort()
result := make([]*Flag, len(list))
for i, name := range list {
result[i] = flags[name]
}
return result
}
// VisitAll visits the flags in lexicographical order, calling fn for each.
// It visits all flags, even those not set.
func VisitAll(fn func(*Flag)) {
for _, f := range flags.formal {
for _, f := range sortFlags(flags.formal) {
fn(f)
}
}
// Visit visits the flags, calling fn for each. It visits only those flags that have been set.
// Visit visits the flags in lexicographical order, calling fn for each.
// It visits only those flags that have been set.
func Visit(fn func(*Flag)) {
for _, f := range flags.actual {
for _, f := range sortFlags(flags.actual) {
fn(f)
}
}
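The visible effect, sketched with illustrative flag names (imports "flag" and "fmt" assumed): callbacks now arrive alphabetically rather than in map iteration order, so PrintDefaults output is stable.
flag.Bool("verbose", false, "enable verbose output")
flag.Int("count", 1, "number of iterations")
flag.String("addr", ":8080", "listen address")
flag.Parse()
flag.VisitAll(func(f *flag.Flag) {
	fmt.Println(f.Name) // prints addr, count, verbose in that order
})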
......@@ -260,7 +279,9 @@ var Usage = func() {
var panicOnError = false
func fail() {
// failf prints to standard error a formatted error and Usage, and then exits the program.
func failf(format string, a ...interface{}) {
fmt.Fprintf(os.Stderr, format, a...)
Usage()
if panicOnError {
panic("flag parse error")
......@@ -268,6 +289,7 @@ func fail() {
os.Exit(2)
}
// NFlag returns the number of flags that have been set.
func NFlag() int { return len(flags.actual) }
// Arg returns the i'th command-line argument. Arg(0) is the first remaining argument
......@@ -415,8 +437,7 @@ func (f *allFlags) parseOne() (ok bool) {
}
name := s[num_minuses:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' {
fmt.Fprintln(os.Stderr, "bad flag syntax:", s)
fail()
failf("bad flag syntax: %s\n", s)
}
// it's a flag. does it have an argument?
......@@ -434,14 +455,12 @@ func (f *allFlags) parseOne() (ok bool) {
m := flags.formal
flag, alreadythere := m[name] // BUG
if !alreadythere {
fmt.Fprintf(os.Stderr, "flag provided but not defined: -%s\n", name)
fail()
failf("flag provided but not defined: -%s\n", name)
}
if fv, ok := flag.Value.(*boolValue); ok { // special case: doesn't need an arg
if has_value {
if !fv.Set(value) {
fmt.Fprintf(os.Stderr, "invalid boolean value %q for flag: -%s\n", value, name)
fail()
failf("invalid boolean value %q for flag: -%s\n", value, name)
}
} else {
fv.Set("true")
......@@ -454,13 +473,11 @@ func (f *allFlags) parseOne() (ok bool) {
value, f.args = f.args[0], f.args[1:]
}
if !has_value {
fmt.Fprintf(os.Stderr, "flag needs an argument: -%s\n", name)
fail()
failf("flag needs an argument: -%s\n", name)
}
ok = flag.Value.Set(value)
if !ok {
fmt.Fprintf(os.Stderr, "invalid value %q for flag: -%s\n", value, name)
fail()
failf("invalid value %q for flag: -%s\n", value, name)
}
}
flags.actual[name] = flag
......
......@@ -8,6 +8,7 @@ import (
. "flag"
"fmt"
"os"
"sort"
"testing"
)
......@@ -77,6 +78,12 @@ func TestEverything(t *testing.T) {
t.Log(k, *v)
}
}
// Now test they're visited in sort order.
var flagNames []string
Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) })
if !sort.StringsAreSorted(flagNames) {
t.Errorf("flag names not sorted: %v", flagNames)
}
}
func TestUsage(t *testing.T) {
......
......@@ -107,7 +107,7 @@ func (f *fmt) writePadding(n int, padding []byte) {
}
// Append b to f.buf, padded on left (w > 0) or right (w < 0 or f.minus)
// clear flags aftewards.
// clear flags afterwards.
func (f *fmt) pad(b []byte) {
var padding []byte
var left, right int
......@@ -124,7 +124,7 @@ func (f *fmt) pad(b []byte) {
}
// append s to buf, padded on left (w > 0) or right (w < 0 or f.minus).
// clear flags aftewards.
// clear flags afterwards.
func (f *fmt) padString(s string) {
var padding []byte
var left, right int
......
......@@ -35,10 +35,15 @@ type ScanState interface {
ReadRune() (rune int, size int, err os.Error)
// UnreadRune causes the next call to ReadRune to return the same rune.
UnreadRune() os.Error
// Token returns the next space-delimited token from the input. If
// a width has been specified, the returned token will be no longer
// than the width.
Token() (token string, err os.Error)
// Token skips space in the input if skipSpace is true, then returns the
// run of Unicode code points c satisfying f(c). If f is nil,
// !unicode.IsSpace(c) is used; that is, the token will hold non-space
// characters. Newlines are treated as space unless the scan operation
// is Scanln, Fscanln or Sscanln, in which case a newline is treated as
// EOF. The returned slice points to shared data that may be overwritten
// by the next call to Token, a call to a Scan function using the ScanState
// as input, or when the calling Scan method returns.
Token(skipSpace bool, f func(int) bool) (token []byte, err os.Error)
// Width returns the value of the width option and whether it has been set.
// The unit is Unicode code points.
Width() (wid int, ok bool)
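A minimal sketch of a custom Scanner written against the new Token signature (the word type is illustrative; assumes imports "fmt", "os", and "unicode", with pre-Go 1 int runes as in this interface):
type word string
func (w *word) Scan(state fmt.ScanState, verb int) os.Error {
	tok, err := state.Token(true, unicode.IsLetter) // skip leading space, take a run of letters
	if err != nil {
		return err
	}
	*w = word(tok) // copy: the slice may be reused by the next Token call
	return nil
}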
......@@ -134,7 +139,7 @@ type scanError struct {
err os.Error
}
const EOF = -1
const eof = -1
// ss is the internal implementation of ScanState.
type ss struct {
......@@ -202,7 +207,7 @@ func (s *ss) getRune() (rune int) {
rune, _, err := s.ReadRune()
if err != nil {
if err == os.EOF {
return EOF
return eof
}
s.error(err)
}
......@@ -214,7 +219,7 @@ func (s *ss) getRune() (rune int) {
// syntax error.
func (s *ss) mustReadRune() (rune int) {
rune = s.getRune()
if rune == EOF {
if rune == eof {
s.error(io.ErrUnexpectedEOF)
}
return
......@@ -238,7 +243,7 @@ func (s *ss) errorString(err string) {
panic(scanError{os.ErrorString(err)})
}
func (s *ss) Token() (tok string, err os.Error) {
func (s *ss) Token(skipSpace bool, f func(int) bool) (tok []byte, err os.Error) {
defer func() {
if e := recover(); e != nil {
if se, ok := e.(scanError); ok {
......@@ -248,10 +253,19 @@ func (s *ss) Token() (tok string, err os.Error) {
}
}
}()
tok = s.token()
if f == nil {
f = notSpace
}
s.buf.Reset()
tok = s.token(skipSpace, f)
return
}
// notSpace is the default scanning function used in Token.
func notSpace(r int) bool {
return !unicode.IsSpace(r)
}
// readRune is a structure to enable reading UTF-8 encoded code points
// from an io.Reader. It is used if the Reader given to the scanner does
// not already implement io.RuneReader.
......@@ -364,7 +378,7 @@ func (s *ss) free(old ssave) {
func (s *ss) skipSpace(stopAtNewline bool) {
for {
rune := s.getRune()
if rune == EOF {
if rune == eof {
return
}
if rune == '\n' {
......@@ -384,24 +398,27 @@ func (s *ss) skipSpace(stopAtNewline bool) {
}
}
// token returns the next space-delimited string from the input. It
// skips white space. For Scanln, it stops at newlines. For Scan,
// newlines are treated as spaces.
func (s *ss) token() string {
s.skipSpace(false)
func (s *ss) token(skipSpace bool, f func(int) bool) []byte {
if skipSpace {
s.skipSpace(false)
}
// read until white space or newline
for {
rune := s.getRune()
if rune == EOF {
if rune == eof {
break
}
if unicode.IsSpace(rune) {
if !f(rune) {
s.UnreadRune()
break
}
s.buf.WriteRune(rune)
}
return s.buf.String()
return s.buf.Bytes()
}
// typeError indicates that the type of the operand did not match the format
......@@ -416,7 +433,7 @@ var boolError = os.ErrorString("syntax error scanning boolean")
// If accept is true, it puts the character into the input token.
func (s *ss) consume(ok string, accept bool) bool {
rune := s.getRune()
if rune == EOF {
if rune == eof {
return false
}
if strings.IndexRune(ok, rune) >= 0 {
......@@ -425,7 +442,7 @@ func (s *ss) consume(ok string, accept bool) bool {
}
return true
}
if rune != EOF && accept {
if rune != eof && accept {
s.UnreadRune()
}
return false
......@@ -434,7 +451,7 @@ func (s *ss) consume(ok string, accept bool) bool {
// peek reports whether the next character is in the ok string, without consuming it.
func (s *ss) peek(ok string) bool {
rune := s.getRune()
if rune != EOF {
if rune != eof {
s.UnreadRune()
}
return strings.IndexRune(ok, rune) >= 0
......@@ -729,7 +746,7 @@ func (s *ss) convertString(verb int) (str string) {
case 'x':
str = s.hexString()
default:
str = s.token() // %s and %v just return the next word
str = string(s.token(true, notSpace)) // %s and %v just return the next word
}
// Empty strings other than with %q are not OK.
if len(str) == 0 && verb != 'q' && s.maxWid > 0 {
......@@ -797,7 +814,7 @@ func (s *ss) hexDigit(digit int) int {
// There must be either two hexadecimal digits or a space character in the input.
func (s *ss) hexByte() (b byte, ok bool) {
rune1 := s.getRune()
if rune1 == EOF {
if rune1 == eof {
return
}
if unicode.IsSpace(rune1) {
......@@ -953,7 +970,7 @@ func (s *ss) doScan(a []interface{}) (numProcessed int, err os.Error) {
if !s.nlIsSpace {
for {
rune := s.getRune()
if rune == '\n' || rune == EOF {
if rune == '\n' || rune == eof {
break
}
if !unicode.IsSpace(rune) {
......@@ -993,7 +1010,7 @@ func (s *ss) advance(format string) (i int) {
// There was space in the format, so there should be space (EOF)
// in the input.
inputc := s.getRune()
if inputc == EOF {
if inputc == eof {
return
}
if !unicode.IsSpace(inputc) {
......
......@@ -88,14 +88,15 @@ type FloatTest struct {
type Xs string
func (x *Xs) Scan(state ScanState, verb int) os.Error {
tok, err := state.Token()
tok, err := state.Token(true, func(r int) bool { return r == verb })
if err != nil {
return err
}
if !regexp.MustCompile("^" + string(verb) + "+$").MatchString(tok) {
s := string(tok)
if !regexp.MustCompile("^" + string(verb) + "+$").MatchString(s) {
return os.ErrorString("syntax error for xs")
}
*x = Xs(tok)
*x = Xs(s)
return nil
}
......@@ -113,9 +114,11 @@ func (s *IntString) Scan(state ScanState, verb int) os.Error {
return err
}
if _, err := Fscan(state, &s.s); err != nil {
tok, err := state.Token(true, nil)
if err != nil {
return err
}
s.s = string(tok)
return nil
}
......@@ -331,7 +334,7 @@ var multiTests = []ScanfMultiTest{
{"%c%c%c", "2\u50c2X", args(&i, &j, &k), args('2', '\u50c2', 'X'), ""},
// Custom scanners.
{"%2e%f", "eefffff", args(&x, &y), args(Xs("ee"), Xs("fffff")), ""},
{"%e%f", "eefffff", args(&x, &y), args(Xs("ee"), Xs("fffff")), ""},
{"%4v%s", "12abcd", args(&z, &s), args(IntString{12, "ab"}, "cd"), ""},
// Errors
......@@ -476,22 +479,12 @@ func verifyInf(str string, t *testing.T) {
}
}
func TestInf(t *testing.T) {
for _, s := range []string{"inf", "+inf", "-inf", "INF", "-INF", "+INF", "Inf", "-Inf", "+Inf"} {
verifyInf(s, t)
}
}
// TODO: there's no conversion from []T to ...T, but we can fake it. These
// functions do the faking. We index the table by the length of the param list.
var fscanf = []func(io.Reader, string, []interface{}) (int, os.Error){
0: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f) },
1: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0]) },
2: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0], i[1]) },
3: func(r io.Reader, f string, i []interface{}) (int, os.Error) { return Fscanf(r, f, i[0], i[1], i[2]) },
}
func testScanfMulti(name string, t *testing.T) {
sliceType := reflect.Typeof(make([]interface{}, 1)).(*reflect.SliceType)
for _, test := range multiTests {
......@@ -501,7 +494,7 @@ func testScanfMulti(name string, t *testing.T) {
} else {
r = newReader(test.text)
}
n, err := fscanf[len(test.in)](r, test.format, test.in)
n, err := Fscanf(r, test.format, test.in...)
if err != nil {
if test.err == "" {
t.Errorf("got error scanning (%q, %q): %q", test.format, test.text, err)
......@@ -830,12 +823,12 @@ func testScanInts(t *testing.T, scan func(*RecursiveInt, *bytes.Buffer) os.Error
i := 1
for ; r != nil; r = r.next {
if r.i != i {
t.Fatal("bad scan: expected %d got %d", i, r.i)
t.Fatalf("bad scan: expected %d got %d", i, r.i)
}
i++
}
if i-1 != intCount {
t.Fatal("bad scan count: expected %d got %d", intCount, i-1)
t.Fatalf("bad scan count: expected %d got %d", intCount, i-1)
}
}
......
......@@ -602,12 +602,12 @@ type (
Else Stmt // else branch; or nil
}
// A CaseClause represents a case of an expression switch statement.
// A CaseClause represents a case of an expression or type switch statement.
CaseClause struct {
Case token.Pos // position of "case" or "default" keyword
Values []Expr // nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
Case token.Pos // position of "case" or "default" keyword
List []Expr // list of expressions or types; nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// A SwitchStmt node represents an expression switch statement.
......@@ -618,20 +618,12 @@ type (
Body *BlockStmt // CaseClauses only
}
// A TypeCaseClause represents a case of a type switch statement.
TypeCaseClause struct {
Case token.Pos // position of "case" or "default" keyword
Types []Expr // nil means default case
Colon token.Pos // position of ":"
Body []Stmt // statement list; or nil
}
// An TypeSwitchStmt node represents a type switch statement.
TypeSwitchStmt struct {
Switch token.Pos // position of "switch" keyword
Init Stmt // initialization statement; or nil
Assign Stmt // x := y.(type)
Body *BlockStmt // TypeCaseClauses only
Assign Stmt // x := y.(type) or y.(type)
Body *BlockStmt // CaseClauses only
}
// A CommClause node represents a case of a select statement.
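With TypeCaseClause folded into CaseClause, one clause type now serves both switch forms; a sketch of how client code distinguishes them (names illustrative, imports "go/ast" and "fmt" assumed):
func describeCases(body *ast.BlockStmt) {
	// body comes from either *ast.SwitchStmt or *ast.TypeSwitchStmt.
	for _, stmt := range body.List {
		cc := stmt.(*ast.CaseClause)
		if cc.List == nil {
			fmt.Println("default case")
			continue
		}
		fmt.Println(len(cc.List), "case expressions or types")
	}
}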
......@@ -687,7 +679,6 @@ func (s *BlockStmt) Pos() token.Pos { return s.Lbrace }
func (s *IfStmt) Pos() token.Pos { return s.If }
func (s *CaseClause) Pos() token.Pos { return s.Case }
func (s *SwitchStmt) Pos() token.Pos { return s.Switch }
func (s *TypeCaseClause) Pos() token.Pos { return s.Case }
func (s *TypeSwitchStmt) Pos() token.Pos { return s.Switch }
func (s *CommClause) Pos() token.Pos { return s.Case }
func (s *SelectStmt) Pos() token.Pos { return s.Select }
......@@ -734,13 +725,7 @@ func (s *CaseClause) End() token.Pos {
}
return s.Colon + 1
}
func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
func (s *TypeCaseClause) End() token.Pos {
if n := len(s.Body); n > 0 {
return s.Body[n-1].End()
}
return s.Colon + 1
}
func (s *SwitchStmt) End() token.Pos { return s.Body.End() }
func (s *TypeSwitchStmt) End() token.Pos { return s.Body.End() }
func (s *CommClause) End() token.Pos {
if n := len(s.Body); n > 0 {
......@@ -772,7 +757,6 @@ func (s *BlockStmt) stmtNode() {}
func (s *IfStmt) stmtNode() {}
func (s *CaseClause) stmtNode() {}
func (s *SwitchStmt) stmtNode() {}
func (s *TypeCaseClause) stmtNode() {}
func (s *TypeSwitchStmt) stmtNode() {}
func (s *CommClause) stmtNode() {}
func (s *SelectStmt) stmtNode() {}
......@@ -937,11 +921,13 @@ func (d *FuncDecl) declNode() {}
// via Doc and Comment fields.
//
type File struct {
Doc *CommentGroup // associated documentation; or nil
Package token.Pos // position of "package" keyword
Name *Ident // package name
Decls []Decl // top-level declarations; or nil
Comments []*CommentGroup // list of all comments in the source file
Doc *CommentGroup // associated documentation; or nil
Package token.Pos // position of "package" keyword
Name *Ident // package name
Decls []Decl // top-level declarations; or nil
Scope *Scope // package scope
Unresolved []*Ident // unresolved global identifiers
Comments []*CommentGroup // list of all comments in the source file
}
......@@ -959,7 +945,7 @@ func (f *File) End() token.Pos {
//
type Package struct {
Name string // package name
Scope *Scope // package scope; or nil
Scope *Scope // package scope
Files map[string]*File // Go source files by filename
}
......
......@@ -425,5 +425,6 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File {
}
}
return &File{doc, pos, NewIdent(pkg.Name), decls, comments}
// TODO(gri) need to compute pkgScope and unresolved identifiers!
return &File{doc, pos, NewIdent(pkg.Name), decls, nil, nil, comments}
}
......@@ -30,15 +30,19 @@ func NotNilFilter(_ string, value reflect.Value) bool {
// Fprint prints the (sub-)tree starting at AST node x to w.
// If fset != nil, position information is interpreted relative
// to that file set. Otherwise positions are printed as integer
// values (file set specific offsets).
//
// A non-nil FieldFilter f may be provided to control the output:
// struct fields for which f(fieldname, fieldvalue) is true are
// printed; all others are filtered from the output.
//
func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) {
func Fprint(w io.Writer, fset *token.FileSet, x interface{}, f FieldFilter) (n int, err os.Error) {
// setup printer
p := printer{
output: w,
fset: fset,
filter: f,
ptrmap: make(map[interface{}]int),
last: '\n', // force printing of line number on first line
......@@ -65,14 +69,15 @@ func Fprint(w io.Writer, x interface{}, f FieldFilter) (n int, err os.Error) {
// Print prints x to standard output, skipping nil fields.
// Print(x) is the same as Fprint(os.Stdout, x, NotNilFilter).
func Print(x interface{}) (int, os.Error) {
return Fprint(os.Stdout, x, NotNilFilter)
// Print(fset, x) is the same as Fprint(os.Stdout, fset, x, NotNilFilter).
func Print(fset *token.FileSet, x interface{}) (int, os.Error) {
return Fprint(os.Stdout, fset, x, NotNilFilter)
}
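A minimal sketch of the updated debug-printing entry point, assuming a file parsed with the go/parser API that appears later in this patch:
package main

import (
	"go/ast"
	"go/parser"
	"go/token"
)

func main() {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "hello.go", "package p; var x = 42", 0)
	if err != nil {
		panic(err)
	}
	// Positions are now resolved through fset and printed as file:line:column.
	ast.Print(fset, f)
}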
type printer struct {
output io.Writer
fset *token.FileSet
filter FieldFilter
ptrmap map[interface{}]int // *reflect.PtrValue -> line number
written int // number of bytes written to output
......@@ -137,16 +142,6 @@ func (p *printer) printf(format string, args ...interface{}) {
// probably be in a different package.
func (p *printer) print(x reflect.Value) {
// Note: This test is only needed because AST nodes
// embed a token.Position, and thus all of them
// understand the String() method (but it only
// applies to the Position field).
// TODO: Should reconsider this AST design decision.
if pos, ok := x.Interface().(token.Position); ok {
p.printf("%s", pos)
return
}
if !NotNilFilter("", x) {
p.printf("nil")
return
......@@ -163,6 +158,7 @@ func (p *printer) print(x reflect.Value) {
p.print(key)
p.printf(": ")
p.print(v.Elem(key))
p.printf("\n")
}
p.indent--
p.printf("}")
......@@ -212,6 +208,11 @@ func (p *printer) print(x reflect.Value) {
p.printf("}")
default:
p.printf("%v", x.Interface())
value := x.Interface()
// position values can be printed nicely if we have a file set
if pos, ok := value.(token.Pos); ok && p.fset != nil {
value = p.fset.Position(pos)
}
p.printf("%v", value)
}
}
......@@ -2,31 +2,31 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements scopes, the objects they contain,
// and object types.
// This file implements scopes and the objects they contain.
package ast
import (
"bytes"
"fmt"
"go/token"
)
// A Scope maintains the set of named language entities declared
// in the scope and a link to the immediately surrounding (outer)
// scope.
//
type Scope struct {
Outer *Scope
Objects []*Object // in declaration order
// Implementation note: In some cases (struct fields,
// function parameters) we need the source order of
// variables. Thus for now, we store scope entries
// in a linear list. If scopes become very large
// (say, for packages), we may need to change this
// to avoid slow lookups.
Objects map[string]*Object
}
// NewScope creates a new scope nested in the outer scope.
func NewScope(outer *Scope) *Scope {
const n = 4 // initial scope capacity, must be > 0
return &Scope{outer, make([]*Object, 0, n)}
const n = 4 // initial scope capacity
return &Scope{outer, make(map[string]*Object, n)}
}
......@@ -34,73 +34,108 @@ func NewScope(outer *Scope) *Scope {
// found in scope s, otherwise it returns nil. Outer scopes
// are ignored.
//
// Lookup always returns nil if name is "_", even if the scope
// contains objects with that name.
//
func (s *Scope) Lookup(name string) *Object {
if name != "_" {
for _, obj := range s.Objects {
if obj.Name == name {
return obj
}
}
}
return nil
return s.Objects[name]
}
// Insert attempts to insert a named object into the scope s.
// If the scope does not contain an object with that name yet
// or if the object is named "_", Insert inserts the object
// and returns it. Otherwise, Insert leaves the scope unchanged
// and returns the object found in the scope instead.
// If the scope does not contain an object with that name yet,
// Insert inserts the object and returns it. Otherwise, Insert
// leaves the scope unchanged and returns the object found in
// the scope instead.
//
func (s *Scope) Insert(obj *Object) *Object {
alt := s.Lookup(obj.Name)
if alt == nil {
s.append(obj)
func (s *Scope) Insert(obj *Object) (alt *Object) {
if alt = s.Objects[obj.Name]; alt == nil {
s.Objects[obj.Name] = obj
alt = obj
}
return alt
return
}
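A minimal usage sketch of the map-based scope above; Insert returns the already-declared object when there is a collision:
package main

import (
	"fmt"
	"go/ast"
)

func main() {
	s := ast.NewScope(nil)
	obj := ast.NewObj(ast.Var, "x")
	if alt := s.Insert(obj); alt != obj {
		fmt.Println("x already declared")
	}
	// Lookup is now a plain access into the scope's Objects map.
	fmt.Println(s.Lookup("x") == obj) // true
}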
func (s *Scope) append(obj *Object) {
s.Objects = append(s.Objects, obj)
// Debugging support
func (s *Scope) String() string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "scope %p {", s)
if s != nil && len(s.Objects) > 0 {
fmt.Fprintln(&buf)
for _, obj := range s.Objects {
fmt.Fprintf(&buf, "\t%s %s\n", obj.Kind, obj.Name)
}
}
fmt.Fprintf(&buf, "}\n")
return buf.String()
}
// ----------------------------------------------------------------------------
// Objects
// An Object describes a language entity such as a package,
// constant, type, variable, or function (incl. methods).
// An Object describes a named language entity such as a package,
// constant, type, variable, function (incl. methods), or label.
//
type Object struct {
Kind Kind
Name string // declared name
Type *Type
Decl interface{} // corresponding Field, XxxSpec or FuncDecl
N int // value of iota for this declaration
Kind ObjKind
Name string // declared name
Decl interface{} // corresponding Field, XxxSpec, FuncDecl, or LabeledStmt; or nil
Type interface{} // placeholder for type information; may be nil
}
// NewObj creates a new object of a given kind and name.
func NewObj(kind Kind, name string) *Object {
func NewObj(kind ObjKind, name string) *Object {
return &Object{Kind: kind, Name: name}
}
// Kind describes what an object represents.
type Kind int
// Pos computes the source position of the declaration of an object name.
// The result may be an invalid position if it cannot be computed
// (obj.Decl may be nil or not correct).
func (obj *Object) Pos() token.Pos {
name := obj.Name
switch d := obj.Decl.(type) {
case *Field:
for _, n := range d.Names {
if n.Name == name {
return n.Pos()
}
}
case *ValueSpec:
for _, n := range d.Names {
if n.Name == name {
return n.Pos()
}
}
case *TypeSpec:
if d.Name.Name == name {
return d.Name.Pos()
}
case *FuncDecl:
if d.Name.Name == name {
return d.Name.Pos()
}
case *LabeledStmt:
if d.Label.Name == name {
return d.Label.Pos()
}
}
return token.NoPos
}
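A minimal sketch of Pos in use, assuming a hand-built Object whose Decl points at a parsed FuncDecl:
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

func main() {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "x.go", "package p; func F() {}", 0)
	if err != nil {
		panic(err)
	}
	decl := f.Decls[0].(*ast.FuncDecl)
	obj := ast.NewObj(ast.Fun, decl.Name.Name)
	obj.Decl = decl
	// Pos finds the declaring identifier via the Decl field.
	fmt.Println(fset.Position(obj.Pos())) // e.g. x.go:1:17, the position of F
}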
// ObjKind describes what an object represents.
type ObjKind int
// The list of possible Object kinds.
const (
Bad Kind = iota // for error handling
Pkg // package
Con // constant
Typ // type
Var // variable
Fun // function or method
Bad ObjKind = iota // for error handling
Pkg // package
Con // constant
Typ // type
Var // variable
Fun // function or method
Lbl // label
)
......@@ -111,132 +146,8 @@ var objKindStrings = [...]string{
Typ: "type",
Var: "var",
Fun: "func",
Lbl: "label",
}
func (kind Kind) String() string { return objKindStrings[kind] }
// IsExported returns whether obj is exported.
func (obj *Object) IsExported() bool { return IsExported(obj.Name) }
// ----------------------------------------------------------------------------
// Types
// A Type represents a Go type.
type Type struct {
Form Form
Obj *Object // corresponding type name, or nil
Scope *Scope // fields and methods, always present
N uint // basic type id, array length, number of function results, or channel direction
Key, Elt *Type // map key and array, pointer, slice, map or channel element
Params *Scope // function (receiver, input and result) parameters, tuple expressions (results of function calls), or nil
Expr Expr // corresponding AST expression
}
// NewType creates a new type of a given form.
func NewType(form Form) *Type {
return &Type{Form: form, Scope: NewScope(nil)}
}
// Form describes the form of a type.
type Form int
// The list of possible type forms.
const (
BadType Form = iota // for error handling
Unresolved // type not fully set up
Basic
Array
Struct
Pointer
Function
Method
Interface
Slice
Map
Channel
Tuple
)
var formStrings = [...]string{
BadType: "badType",
Unresolved: "unresolved",
Basic: "basic",
Array: "array",
Struct: "struct",
Pointer: "pointer",
Function: "function",
Method: "method",
Interface: "interface",
Slice: "slice",
Map: "map",
Channel: "channel",
Tuple: "tuple",
}
func (form Form) String() string { return formStrings[form] }
// The list of basic type id's.
const (
Bool = iota
Byte
Uint
Int
Float
Complex
Uintptr
String
Uint8
Uint16
Uint32
Uint64
Int8
Int16
Int32
Int64
Float32
Float64
Complex64
Complex128
// TODO(gri) ideal types are missing
)
var BasicTypes = map[uint]string{
Bool: "bool",
Byte: "byte",
Uint: "uint",
Int: "int",
Float: "float",
Complex: "complex",
Uintptr: "uintptr",
String: "string",
Uint8: "uint8",
Uint16: "uint16",
Uint32: "uint32",
Uint64: "uint64",
Int8: "int8",
Int16: "int16",
Int32: "int32",
Int64: "int64",
Float32: "float32",
Float64: "float64",
Complex64: "complex64",
Complex128: "complex128",
}
func (kind ObjKind) String() string { return objKindStrings[kind] }
......@@ -234,7 +234,7 @@ func Walk(v Visitor, node Node) {
}
case *CaseClause:
walkExprList(v, n.Values)
walkExprList(v, n.List)
walkStmtList(v, n.Body)
case *SwitchStmt:
......@@ -246,12 +246,6 @@ func Walk(v Visitor, node Node) {
}
Walk(v, n.Body)
case *TypeCaseClause:
for _, x := range n.Types {
Walk(v, x)
}
walkStmtList(v, n.Body)
case *TypeSwitchStmt:
if n.Init != nil {
Walk(v, n.Init)
......
......@@ -14,7 +14,7 @@ import (
"io"
"io/ioutil"
"os"
pathutil "path"
"path/filepath"
)
......@@ -198,7 +198,7 @@ func ParseDir(fset *token.FileSet, path string, filter func(*os.FileInfo) bool,
for i := 0; i < len(list); i++ {
d := &list[i]
if filter == nil || filter(d) {
filenames[n] = pathutil.Join(path, d.Name)
filenames[n] = filepath.Join(path, d.Name)
n++
}
}
......
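A minimal sketch of ParseDir with a filter, assuming it returns a map keyed by package name as in the library of this era; "testdata" is a placeholder directory:
package main

import (
	"fmt"
	"go/parser"
	"go/token"
	"os"
	"strings"
)

func main() {
	fset := token.NewFileSet()
	// Only parse .go files; the filter receives a *os.FileInfo in this API.
	filter := func(d *os.FileInfo) bool { return strings.HasSuffix(d.Name, ".go") }
	pkgs, err := parser.ParseDir(fset, "testdata", filter, 0)
	if err != nil {
		fmt.Println(err)
		return
	}
	for name := range pkgs {
		fmt.Println("package", name)
	}
}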
......@@ -21,6 +21,7 @@ var illegalInputs = []interface{}{
`package p; func f() { if /* should have condition */ {} };`,
`package p; func f() { if ; /* should have condition */ {} };`,
`package p; func f() { if f(); /* should have condition */ {} };`,
`package p; const c; /* should have constant value */`,
}
......@@ -73,7 +74,7 @@ var validFiles = []string{
func TestParse3(t *testing.T) {
for _, filename := range validFiles {
_, err := ParseFile(fset, filename, nil, 0)
_, err := ParseFile(fset, filename, nil, DeclarationErrors)
if err != nil {
t.Errorf("ParseFile(%s): %v", filename, err)
}
......
......@@ -12,7 +12,7 @@ import (
"go/token"
"io"
"os"
"path"
"path/filepath"
"runtime"
"tabwriter"
)
......@@ -94,22 +94,23 @@ type printer struct {
// written using writeItem.
last token.Position
// HTML support
lastTaggedLine int // last line for which a line tag was written
// The list of all source comments, in order of appearance.
comments []*ast.CommentGroup // may be nil
cindex int // current comment index
useNodeComments bool // if not set, ignore lead and line comments of nodes
// Cache of already computed node sizes.
nodeSizes map[ast.Node]int
}
func (p *printer) init(output io.Writer, cfg *Config, fset *token.FileSet) {
func (p *printer) init(output io.Writer, cfg *Config, fset *token.FileSet, nodeSizes map[ast.Node]int) {
p.output = output
p.Config = *cfg
p.fset = fset
p.errors = make(chan os.Error)
p.buffer = make([]whiteSpace, 0, 16) // whitespace sequences are short
p.nodeSizes = nodeSizes
}
......@@ -244,7 +245,7 @@ func (p *printer) writeItem(pos token.Position, data []byte) {
}
if debug {
// do not update p.pos - use write0
_, filename := path.Split(pos.Filename)
_, filename := filepath.Split(pos.Filename)
p.write0([]byte(fmt.Sprintf("[%s:%d:%d]", filename, pos.Line, pos.Column)))
}
p.write(data)
......@@ -994,13 +995,8 @@ type Config struct {
}
// Fprint "pretty-prints" an AST node to output and returns the number
// of bytes written and an error (if any) for a given configuration cfg.
// Position information is interpreted relative to the file set fset.
// The node type must be *ast.File, or assignment-compatible to ast.Expr,
// ast.Decl, ast.Spec, or ast.Stmt.
//
func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) (int, os.Error) {
// fprint implements Fprint and takes a nodeSizes map for setting up the printer state.
func (cfg *Config) fprint(output io.Writer, fset *token.FileSet, node interface{}, nodeSizes map[ast.Node]int) (int, os.Error) {
// redirect output through a trimmer to eliminate trailing whitespace
// (Input to a tabwriter must be untrimmed since trailing tabs provide
// formatting information. The tabwriter could provide trimming
......@@ -1029,7 +1025,7 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{
// setup printer and print node
var p printer
p.init(output, cfg, fset)
p.init(output, cfg, fset, nodeSizes)
go func() {
switch n := node.(type) {
case ast.Expr:
......@@ -1076,6 +1072,17 @@ func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{
}
// Fprint "pretty-prints" an AST node to output and returns the number
// of bytes written and an error (if any) for a given configuration cfg.
// Position information is interpreted relative to the file set fset.
// The node type must be *ast.File, or assignment-compatible to ast.Expr,
// ast.Decl, ast.Spec, or ast.Stmt.
//
func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) (int, os.Error) {
return cfg.fprint(output, fset, node, make(map[ast.Node]int))
}
// Fprint "pretty-prints" an AST node to output.
// It calls Config.Fprint with default settings.
//
......
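A minimal sketch of the Config.Fprint wrapper above; the Tabwidth field name is an assumption taken from the current printer.Config:
package main

import (
	"go/parser"
	"go/printer"
	"go/token"
	"os"
)

func main() {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "x.go", "package p\nfunc F() {}\n", 0)
	if err != nil {
		panic(err)
	}
	cfg := printer.Config{Tabwidth: 8} // Tabwidth assumed; other fields default
	// Each call builds its own nodeSizes cache via cfg.fprint.
	cfg.Fprint(os.Stdout, fset, f)
}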
......@@ -11,8 +11,9 @@ import (
"go/ast"
"go/parser"
"go/token"
"path"
"path/filepath"
"testing"
"time"
)
......@@ -45,7 +46,7 @@ const (
)
func check(t *testing.T, source, golden string, mode checkMode) {
func runcheck(t *testing.T, source, golden string, mode checkMode) {
// parse source
prog, err := parser.ParseFile(fset, source, nil, parser.ParseComments)
if err != nil {
......@@ -109,6 +110,32 @@ func check(t *testing.T, source, golden string, mode checkMode) {
}
func check(t *testing.T, source, golden string, mode checkMode) {
// start a timer to produce a time-out signal
tc := make(chan int)
go func() {
time.Sleep(20e9) // plenty of safety margin, even for very slow machines
tc <- 0
}()
// run the test
cc := make(chan int)
go func() {
runcheck(t, source, golden, mode)
cc <- 0
}()
// wait for the first finisher
select {
case <-tc:
// test running past time out
t.Errorf("%s: running too slowly", source)
case <-cc:
// test finished within allotted time margin
}
}
type entry struct {
source, golden string
mode checkMode
......@@ -124,13 +151,14 @@ var data = []entry{
{"expressions.input", "expressions.raw", rawFormat},
{"declarations.input", "declarations.golden", 0},
{"statements.input", "statements.golden", 0},
{"slow.input", "slow.golden", 0},
}
func TestFiles(t *testing.T) {
for _, e := range data {
source := path.Join(dataDir, e.source)
golden := path.Join(dataDir, e.golden)
source := filepath.Join(dataDir, e.source)
golden := filepath.Join(dataDir, e.golden)
check(t, source, golden, e.mode)
// TODO(gri) check that golden is idempotent
//check(t, golden, golden, e.mode);
......
......@@ -224,11 +224,7 @@ func _() {
_ = struct{ x int }{0}
_ = struct{ x, y, z int }{0, 1, 2}
_ = struct{ int }{0}
_ = struct {
s struct {
int
}
}{struct{ int }{0}} // compositeLit context not propagated => multiLine result
_ = struct{ s struct{ int } }{struct{ int }{0}}
}
......
......@@ -224,7 +224,7 @@ func _() {
_ = struct{ x int }{0}
_ = struct{ x, y, z int }{0, 1, 2}
_ = struct{ int }{0}
_ = struct{ s struct { int } }{struct{ int}{0}} // compositeLit context not propagated => multiLine result
_ = struct{ s struct { int } }{struct{ int}{0} }
}
......
......@@ -224,11 +224,7 @@ func _() {
_ = struct{ x int }{0}
_ = struct{ x, y, z int }{0, 1, 2}
_ = struct{ int }{0}
_ = struct {
s struct {
int
}
}{struct{ int }{0}} // compositeLit context not propagated => multiLine result
_ = struct{ s struct{ int } }{struct{ int }{0}}
}
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deepequal_test
import (
"testing"
"google3/spam/archer/frontend/deepequal"
)
func TestTwoNilValues(t *testing.T) {
if err := deepequal.Check(nil, nil); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
type Foo struct {
bar *Bar
bang *Bar
}
type Bar struct {
baz *Baz
foo []*Foo
}
type Baz struct {
entries map[int]interface{}
whatever string
}
func newFoo() *Foo {
return &Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
42: &Foo{},
21: &Bar{},
11: &Baz{whatever: "it's just a test"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
43: &Foo{},
22: &Bar{},
13: &Baz{whatever: "this is nuts"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
61: &Foo{},
71: &Bar{},
11: &Baz{whatever: "no, it's Go"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
0: &Foo{},
-2: &Bar{},
-11: &Baz{whatever: "we need to go deeper"}}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
-2: &Foo{},
-5: &Bar{},
-7: &Baz{whatever: "are you serious?"}}}},
bang: &Bar{foo: []*Foo{}}},
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
-100: &Foo{},
50: &Bar{},
20: &Baz{whatever: "na, not really ..."}}}},
bang: &Bar{foo: []*Foo{}}}}}}}}},
&Foo{bar: &Bar{baz: &Baz{
entries: map[int]interface{}{
2: &Foo{},
1: &Bar{},
-1: &Baz{whatever: "... it's just a test."}}}},
bang: &Bar{foo: []*Foo{}}}}}}}}}
}
func TestElaborate(t *testing.T) {
a := newFoo()
b := newFoo()
if err := deepequal.Check(a, b); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package deepequal_test
import (
"testing"
"google3/spam/archer/frontend/deepequal"
)
func TestTwoNilValues(t *testing.T) {
if err := deepequal.Check(nil, nil); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
type Foo struct {
bar *Bar
bang *Bar
}
type Bar struct {
baz *Baz
foo []*Foo
}
type Baz struct {
entries map[int]interface{}
whatever string
}
func newFoo() (*Foo) {
return &Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
42: &Foo{},
21: &Bar{},
11: &Baz{ whatever: "it's just a test" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
43: &Foo{},
22: &Bar{},
13: &Baz{ whatever: "this is nuts" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
61: &Foo{},
71: &Bar{},
11: &Baz{ whatever: "no, it's Go" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
0: &Foo{},
-2: &Bar{},
-11: &Baz{ whatever: "we need to go deeper" }}}},
bang: &Bar{foo: []*Foo{
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
-2: &Foo{},
-5: &Bar{},
-7: &Baz{ whatever: "are you serious?" }}}},
bang: &Bar{foo: []*Foo{}}},
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
-100: &Foo{},
50: &Bar{},
20: &Baz{ whatever: "na, not really ..." }}}},
bang: &Bar{foo: []*Foo{}}}}}}}}},
&Foo{bar: &Bar{ baz: &Baz{
entries: map[int]interface{}{
2: &Foo{},
1: &Bar{},
-1: &Baz{ whatever: "... it's just a test." }}}},
bang: &Bar{foo: []*Foo{}}}}}}}}}
}
func TestElaborate(t *testing.T) {
a := newFoo()
b := newFoo()
if err := deepequal.Check(a, b); err != nil {
t.Errorf("expected nil, saw %v", err)
}
}
......@@ -23,7 +23,7 @@ package scanner
import (
"bytes"
"go/token"
"path"
"path/filepath"
"strconv"
"unicode"
"utf8"
......@@ -118,7 +118,7 @@ func (S *Scanner) Init(file *token.File, src []byte, err ErrorHandler, mode uint
panic("file size does not match src len")
}
S.file = file
S.dir, _ = path.Split(file.Name())
S.dir, _ = filepath.Split(file.Name())
S.src = src
S.err = err
S.mode = mode
......@@ -177,13 +177,13 @@ var prefix = []byte("//line ")
func (S *Scanner) interpretLineComment(text []byte) {
if bytes.HasPrefix(text, prefix) {
// get filename and line number, if any
if i := bytes.Index(text, []byte{':'}); i > 0 {
if i := bytes.LastIndex(text, []byte{':'}); i > 0 {
if line, err := strconv.Atoi(string(text[i+1:])); err == nil && line > 0 {
// valid //line filename:line comment;
filename := path.Clean(string(text[len(prefix):i]))
if filename[0] != '/' {
filename := filepath.Clean(string(text[len(prefix):i]))
if !filepath.IsAbs(filename) {
// make filename relative to current directory
filename = path.Join(S.dir, filename)
filename = filepath.Join(S.dir, filename)
}
// update scanner position
S.file.AddLineInfo(S.lineOffset, filename, line-1) // -1 since comment applies to next line
......
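A minimal sketch of the position remapping done by interpretLineComment; the relative filename in the //line comment is joined with the file's directory via filepath:
package main

import (
	"fmt"
	"go/scanner"
	"go/token"
)

func main() {
	src := []byte("package p\n//line foo.go:10\nvar x int\n")
	fset := token.NewFileSet()
	file := fset.AddFile("dir/orig.go", fset.Base(), len(src))
	var s scanner.Scanner
	s.Init(file, src, nil, 0)
	for {
		pos, tok, _ := s.Scan()
		if tok == token.EOF {
			break
		}
		// Tokens on the line after the //line comment report dir/foo.go:10.
		fmt.Println(fset.Position(pos), tok)
	}
}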
......@@ -7,6 +7,8 @@ package scanner
import (
"go/token"
"os"
"path/filepath"
"runtime"
"testing"
)
......@@ -443,32 +445,41 @@ func TestSemis(t *testing.T) {
}
}
var segments = []struct {
type segment struct {
srcline string // a line of source text
filename string // filename for current token
line int // line number for current token
}{
}
var segments = []segment{
// exactly one token per line since the test consumes one token per segment
{" line1", "dir/TestLineComments", 1},
{"\nline2", "dir/TestLineComments", 2},
{"\nline3 //line File1.go:100", "dir/TestLineComments", 3}, // bad line comment, ignored
{"\nline4", "dir/TestLineComments", 4},
{"\n//line File1.go:100\n line100", "dir/File1.go", 100},
{"\n//line File2.go:200\n line200", "dir/File2.go", 200},
{" line1", filepath.Join("dir", "TestLineComments"), 1},
{"\nline2", filepath.Join("dir", "TestLineComments"), 2},
{"\nline3 //line File1.go:100", filepath.Join("dir", "TestLineComments"), 3}, // bad line comment, ignored
{"\nline4", filepath.Join("dir", "TestLineComments"), 4},
{"\n//line File1.go:100\n line100", filepath.Join("dir", "File1.go"), 100},
{"\n//line File2.go:200\n line200", filepath.Join("dir", "File2.go"), 200},
{"\n//line :1\n line1", "dir", 1},
{"\n//line foo:42\n line42", "dir/foo", 42},
{"\n //line foo:42\n line44", "dir/foo", 44}, // bad line comment, ignored
{"\n//line foo 42\n line46", "dir/foo", 46}, // bad line comment, ignored
{"\n//line foo:42 extra text\n line48", "dir/foo", 48}, // bad line comment, ignored
{"\n//line /bar:42\n line42", "/bar", 42},
{"\n//line ./foo:42\n line42", "dir/foo", 42},
{"\n//line a/b/c/File1.go:100\n line100", "dir/a/b/c/File1.go", 100},
{"\n//line foo:42\n line42", filepath.Join("dir", "foo"), 42},
{"\n //line foo:42\n line44", filepath.Join("dir", "foo"), 44}, // bad line comment, ignored
{"\n//line foo 42\n line46", filepath.Join("dir", "foo"), 46}, // bad line comment, ignored
{"\n//line foo:42 extra text\n line48", filepath.Join("dir", "foo"), 48}, // bad line comment, ignored
{"\n//line /bar:42\n line42", string(filepath.Separator) + "bar", 42},
{"\n//line ./foo:42\n line42", filepath.Join("dir", "foo"), 42},
{"\n//line a/b/c/File1.go:100\n line100", filepath.Join("dir", "a", "b", "c", "File1.go"), 100},
}
var winsegments = []segment{
{"\n//line c:\\dir\\File1.go:100\n line100", "c:\\dir\\File1.go", 100},
}
// Verify that comments of the form "//line filename:line" are interpreted correctly.
func TestLineComments(t *testing.T) {
if runtime.GOOS == "windows" {
segments = append(segments, winsegments...)
}
// make source
var src string
for _, e := range segments {
......@@ -477,7 +488,7 @@ func TestLineComments(t *testing.T) {
// verify scan
var S Scanner
file := fset.AddFile("dir/TestLineComments", fset.Base(), len(src))
file := fset.AddFile(filepath.Join("dir", "TestLineComments"), fset.Base(), len(src))
S.Init(file, []byte(src), nil, 0)
for _, s := range segments {
p, _, lit := S.Scan()
......
......@@ -2,15 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements scope support functions.
// DEPRECATED FILE - WILL GO AWAY EVENTUALLY.
//
// Scope handling is now done in go/parser.
// The functionality here is only present to
// keep the typechecker running for now.
package typechecker
import (
"fmt"
"go/ast"
"go/token"
)
import "go/ast"
func (tc *typechecker) openScope() *ast.Scope {
......@@ -24,52 +24,25 @@ func (tc *typechecker) closeScope() {
}
// objPos computes the source position of the declaration of an object name.
// Only required for error reporting, so doesn't have to be fast.
func objPos(obj *ast.Object) (pos token.Pos) {
switch d := obj.Decl.(type) {
case *ast.Field:
for _, n := range d.Names {
if n.Name == obj.Name {
return n.Pos()
}
}
case *ast.ValueSpec:
for _, n := range d.Names {
if n.Name == obj.Name {
return n.Pos()
}
}
case *ast.TypeSpec:
return d.Name.Pos()
case *ast.FuncDecl:
return d.Name.Pos()
}
if debug {
fmt.Printf("decl = %T\n", obj.Decl)
}
panic("unreachable")
}
// declInScope declares an object of a given kind and name in scope and sets the object's Decl and N fields.
// It returns the newly allocated object. If an object with the same name already exists in scope, an error
// is reported and the object is not inserted.
// (Objects with _ name are always inserted into a scope without errors, but they cannot be found.)
func (tc *typechecker) declInScope(scope *ast.Scope, kind ast.Kind, name *ast.Ident, decl interface{}, n int) *ast.Object {
func (tc *typechecker) declInScope(scope *ast.Scope, kind ast.ObjKind, name *ast.Ident, decl interface{}, n int) *ast.Object {
obj := ast.NewObj(kind, name.Name)
obj.Decl = decl
obj.N = n
//obj.N = n
name.Obj = obj
if alt := scope.Insert(obj); alt != obj {
tc.Errorf(name.Pos(), "%s already declared at %s", name.Name, objPos(alt))
if name.Name != "_" {
if alt := scope.Insert(obj); alt != obj {
tc.Errorf(name.Pos(), "%s already declared at %s", name.Name, tc.fset.Position(alt.Pos()).String())
}
}
return obj
}
// decl is the same as declInScope(tc.topScope, ...)
func (tc *typechecker) decl(kind ast.Kind, name *ast.Ident, decl interface{}, n int) *ast.Object {
func (tc *typechecker) decl(kind ast.ObjKind, name *ast.Ident, decl interface{}, n int) *ast.Object {
return tc.declInScope(tc.topScope, kind, name, decl, n)
}
......@@ -91,7 +64,7 @@ func (tc *typechecker) find(name *ast.Ident) (obj *ast.Object) {
// findField returns the object with the given name if visible in the type's scope.
// If no such object is found, an error is reported and a bad object is returned instead.
func (tc *typechecker) findField(typ *ast.Type, name *ast.Ident) (obj *ast.Object) {
func (tc *typechecker) findField(typ *Type, name *ast.Ident) (obj *ast.Object) {
// TODO(gri) This is simplistic at the moment and ignores anonymous fields.
obj = typ.Scope.Lookup(name.Name)
if obj == nil {
......@@ -100,20 +73,3 @@ func (tc *typechecker) findField(typ *ast.Type, name *ast.Ident) (obj *ast.Objec
}
return
}
// printScope prints the objects in a scope.
func printScope(scope *ast.Scope) {
fmt.Printf("scope %p {", scope)
if scope != nil && len(scope.Objects) > 0 {
fmt.Println()
for _, obj := range scope.Objects {
form := "void"
if obj.Type != nil {
form = obj.Type.Form.String()
}
fmt.Printf("\t%s\t%s\n", obj.Name, form)
}
}
fmt.Printf("}\n")
}
......@@ -7,7 +7,7 @@
package P1
const (
c1 /* ERROR "missing initializer" */
c1 = 0
c2 int = 0
c3, c4 = 0
)
......@@ -27,8 +27,11 @@ func (T) m1 /* ERROR "already declared" */ () {}
func (x *T) m2(u, x /* ERROR "already declared" */ int) {}
func (x *T) m3(a, b, c int) (u, x /* ERROR "already declared" */ int) {}
func (T) _(x, x /* ERROR "already declared" */ int) {}
func (T) _() (x, x /* ERROR "already declared" */ int) {}
// The following are disabled for now because the typechecker
// is in the process of being rewritten and cannot handle them
// at the moment
//func (T) _(x, x /* "already declared" */ int) {}
//func (T) _() (x, x /* "already declared" */ int) {}
//func (PT) _() {}
......
......@@ -7,5 +7,5 @@
package P4
const (
c0 /* ERROR "missing initializer" */
c0 = 0
)
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package typechecker
import "go/ast"
// A Type represents a Go type.
type Type struct {
Form Form
Obj *ast.Object // corresponding type name, or nil
Scope *ast.Scope // fields and methods, always present
N uint // basic type id, array length, number of function results, or channel direction
Key, Elt *Type // map key and array, pointer, slice, map or channel element
Params *ast.Scope // function (receiver, input and result) parameters, tuple expressions (results of function calls), or nil
Expr ast.Expr // corresponding AST expression
}
// NewType creates a new type of a given form.
func NewType(form Form) *Type {
return &Type{Form: form, Scope: ast.NewScope(nil)}
}
// Form describes the form of a type.
type Form int
// The list of possible type forms.
const (
BadType Form = iota // for error handling
Unresolved // type not fully set up
Basic
Array
Struct
Pointer
Function
Method
Interface
Slice
Map
Channel
Tuple
)
var formStrings = [...]string{
BadType: "badType",
Unresolved: "unresolved",
Basic: "basic",
Array: "array",
Struct: "struct",
Pointer: "pointer",
Function: "function",
Method: "method",
Interface: "interface",
Slice: "slice",
Map: "map",
Channel: "channel",
Tuple: "tuple",
}
func (form Form) String() string { return formStrings[form] }
// The list of basic type id's.
const (
Bool = iota
Byte
Uint
Int
Float
Complex
Uintptr
String
Uint8
Uint16
Uint32
Uint64
Int8
Int16
Int32
Int64
Float32
Float64
Complex64
Complex128
// TODO(gri) ideal types are missing
)
var BasicTypes = map[uint]string{
Bool: "bool",
Byte: "byte",
Uint: "uint",
Int: "int",
Float: "float",
Complex: "complex",
Uintptr: "uintptr",
String: "string",
Uint8: "uint8",
Uint16: "uint16",
Uint32: "uint32",
Uint64: "uint64",
Int8: "int8",
Int16: "int16",
Int32: "int32",
Int64: "int64",
Float32: "float32",
Float64: "float64",
Complex64: "complex64",
Complex128: "complex128",
}
......@@ -65,6 +65,7 @@ type typechecker struct {
fset *token.FileSet
scanner.ErrorVector
importer Importer
globals []*ast.Object // list of global objects
topScope *ast.Scope // current top-most scope
cyclemap map[*ast.Object]bool // for cycle detection
iota int // current value of iota
......@@ -94,7 +95,7 @@ phase 1: declare all global objects; also collect all function and method declar
- report global double declarations
phase 2: bind methods to their receiver base types
- received base types must be declared in the package, thus for
- receiver base types must be declared in the package, thus for
each method a corresponding (unresolved) type must exist
- report method double declarations and errors with base types
......@@ -142,16 +143,16 @@ func (tc *typechecker) checkPackage(pkg *ast.Package) {
}
// phase 3: resolve all global objects
// (note that objects with _ name are also in the scope)
tc.cyclemap = make(map[*ast.Object]bool)
for _, obj := range tc.topScope.Objects {
for _, obj := range tc.globals {
tc.resolve(obj)
}
assert(len(tc.cyclemap) == 0)
// 4: sequentially typecheck function and method bodies
for _, f := range funcs {
tc.checkBlock(f.Body.List, f.Name.Obj.Type)
ftype, _ := f.Name.Obj.Type.(*Type)
tc.checkBlock(f.Body.List, ftype)
}
pkg.Scope = tc.topScope
......@@ -183,11 +184,11 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
}
}
for _, name := range s.Names {
tc.decl(ast.Con, name, s, iota)
tc.globals = append(tc.globals, tc.decl(ast.Con, name, s, iota))
}
case token.VAR:
for _, name := range s.Names {
tc.decl(ast.Var, name, s, 0)
tc.globals = append(tc.globals, tc.decl(ast.Var, name, s, 0))
}
default:
panic("unreachable")
......@@ -196,9 +197,10 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
iota++
case *ast.TypeSpec:
obj := tc.decl(ast.Typ, s.Name, s, 0)
tc.globals = append(tc.globals, obj)
// give all type objects an unresolved type so
// that we can collect methods in the type scope
typ := ast.NewType(ast.Unresolved)
typ := NewType(Unresolved)
obj.Type = typ
typ.Obj = obj
default:
......@@ -208,7 +210,7 @@ func (tc *typechecker) declGlobal(global ast.Decl) {
case *ast.FuncDecl:
if d.Recv == nil {
tc.decl(ast.Fun, d.Name, d, 0)
tc.globals = append(tc.globals, tc.decl(ast.Fun, d.Name, d, 0))
}
default:
......@@ -239,8 +241,8 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) {
} else if obj.Kind != ast.Typ {
tc.Errorf(name.Pos(), "invalid receiver: %s is not a type", name.Name)
} else {
typ := obj.Type
assert(typ.Form == ast.Unresolved)
typ := obj.Type.(*Type)
assert(typ.Form == Unresolved)
scope = typ.Scope
}
}
......@@ -261,7 +263,7 @@ func (tc *typechecker) bindMethod(method *ast.FuncDecl) {
func (tc *typechecker) resolve(obj *ast.Object) {
// check for declaration cycles
if tc.cyclemap[obj] {
tc.Errorf(objPos(obj), "illegal cycle in declaration of %s", obj.Name)
tc.Errorf(obj.Pos(), "illegal cycle in declaration of %s", obj.Name)
obj.Kind = ast.Bad
return
}
......@@ -271,7 +273,7 @@ func (tc *typechecker) resolve(obj *ast.Object) {
}()
// resolve non-type objects
typ := obj.Type
typ, _ := obj.Type.(*Type)
if typ == nil {
switch obj.Kind {
case ast.Bad:
......@@ -282,12 +284,12 @@ func (tc *typechecker) resolve(obj *ast.Object) {
case ast.Var:
tc.declVar(obj)
//obj.Type = tc.typeFor(nil, obj.Decl.(*ast.ValueSpec).Type, false)
obj.Type = tc.typeFor(nil, obj.Decl.(*ast.ValueSpec).Type, false)
case ast.Fun:
obj.Type = ast.NewType(ast.Function)
obj.Type = NewType(Function)
t := obj.Decl.(*ast.FuncDecl).Type
tc.declSignature(obj.Type, nil, t.Params, t.Results)
tc.declSignature(obj.Type.(*Type), nil, t.Params, t.Results)
default:
// type objects have non-nil types when resolve is called
......@@ -300,32 +302,34 @@ func (tc *typechecker) resolve(obj *ast.Object) {
}
// resolve type objects
if typ.Form == ast.Unresolved {
if typ.Form == Unresolved {
tc.typeFor(typ, typ.Obj.Decl.(*ast.TypeSpec).Type, false)
// provide types for all methods
for _, obj := range typ.Scope.Objects {
if obj.Kind == ast.Fun {
assert(obj.Type == nil)
obj.Type = ast.NewType(ast.Method)
obj.Type = NewType(Method)
f := obj.Decl.(*ast.FuncDecl)
t := f.Type
tc.declSignature(obj.Type, f.Recv, t.Params, t.Results)
tc.declSignature(obj.Type.(*Type), f.Recv, t.Params, t.Results)
}
}
}
}
func (tc *typechecker) checkBlock(body []ast.Stmt, ftype *ast.Type) {
func (tc *typechecker) checkBlock(body []ast.Stmt, ftype *Type) {
tc.openScope()
defer tc.closeScope()
// inject function/method parameters into block scope, if any
if ftype != nil {
for _, par := range ftype.Params.Objects {
obj := tc.topScope.Insert(par)
assert(obj == par) // ftype has no double declarations
if par.Name != "_" {
obj := tc.topScope.Insert(par)
assert(obj == par) // ftype has no double declarations
}
}
}
......@@ -362,8 +366,8 @@ func (tc *typechecker) declFields(scope *ast.Scope, fields *ast.FieldList, ref b
}
func (tc *typechecker) declSignature(typ *ast.Type, recv, params, results *ast.FieldList) {
assert((typ.Form == ast.Method) == (recv != nil))
func (tc *typechecker) declSignature(typ *Type, recv, params, results *ast.FieldList) {
assert((typ.Form == Method) == (recv != nil))
typ.Params = ast.NewScope(nil)
tc.declFields(typ.Params, recv, true)
tc.declFields(typ.Params, params, true)
......@@ -371,7 +375,7 @@ func (tc *typechecker) declSignature(typ *ast.Type, recv, params, results *ast.F
}
func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Type) {
func (tc *typechecker) typeFor(def *Type, x ast.Expr, ref bool) (typ *Type) {
x = unparen(x)
// type name
......@@ -381,10 +385,10 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
if obj.Kind != ast.Typ {
tc.Errorf(t.Pos(), "%s is not a type", t.Name)
if def == nil {
typ = ast.NewType(ast.BadType)
typ = NewType(BadType)
} else {
typ = def
typ.Form = ast.BadType
typ.Form = BadType
}
typ.Expr = x
return
......@@ -393,7 +397,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
if !ref {
tc.resolve(obj) // check for cycles even if type resolved
}
typ = obj.Type
typ = obj.Type.(*Type)
if def != nil {
// new type declaration: copy type structure
......@@ -410,7 +414,7 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
// type literal
typ = def
if typ == nil {
typ = ast.NewType(ast.BadType)
typ = NewType(BadType)
}
typ.Expr = x
......@@ -419,42 +423,42 @@ func (tc *typechecker) typeFor(def *ast.Type, x ast.Expr, ref bool) (typ *ast.Ty
if debug {
fmt.Println("qualified identifier unimplemented")
}
typ.Form = ast.BadType
typ.Form = BadType
case *ast.StarExpr:
typ.Form = ast.Pointer
typ.Form = Pointer
typ.Elt = tc.typeFor(nil, t.X, true)
case *ast.ArrayType:
if t.Len != nil {
typ.Form = ast.Array
typ.Form = Array
// TODO(gri) compute the real length
// (this may call resolve recursively)
(*typ).N = 42
} else {
typ.Form = ast.Slice
typ.Form = Slice
}
typ.Elt = tc.typeFor(nil, t.Elt, t.Len == nil)
case *ast.StructType:
typ.Form = ast.Struct
typ.Form = Struct
tc.declFields(typ.Scope, t.Fields, false)
case *ast.FuncType:
typ.Form = ast.Function
typ.Form = Function
tc.declSignature(typ, nil, t.Params, t.Results)
case *ast.InterfaceType:
typ.Form = ast.Interface
typ.Form = Interface
tc.declFields(typ.Scope, t.Methods, true)
case *ast.MapType:
typ.Form = ast.Map
typ.Form = Map
typ.Key = tc.typeFor(nil, t.Key, true)
typ.Elt = tc.typeFor(nil, t.Value, true)
case *ast.ChanType:
typ.Form = ast.Channel
typ.Form = Channel
typ.N = uint(t.Dir)
typ.Elt = tc.typeFor(nil, t.Value, true)
......
......@@ -93,7 +93,7 @@ func expectedErrors(t *testing.T, pkg *ast.Package) (list scanner.ErrorList) {
func testFilter(f *os.FileInfo) bool {
return strings.HasSuffix(f.Name, ".go") && f.Name[0] != '.'
return strings.HasSuffix(f.Name, ".src") && f.Name[0] != '.'
}
......
......@@ -24,8 +24,8 @@ func init() {
Universe = ast.NewScope(nil)
// basic types
for n, name := range ast.BasicTypes {
typ := ast.NewType(ast.Basic)
for n, name := range BasicTypes {
typ := NewType(Basic)
typ.N = n
obj := ast.NewObj(ast.Typ, name)
obj.Type = typ
......
......@@ -50,7 +50,7 @@ func testError(t *testing.T) {
func TestUintCodec(t *testing.T) {
defer testError(t)
b := new(bytes.Buffer)
encState := newEncoderState(nil, b)
encState := newEncoderState(b)
for _, tt := range encodeT {
b.Reset()
encState.encodeUint(tt.x)
......@@ -58,7 +58,7 @@ func TestUintCodec(t *testing.T) {
t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes())
}
}
decState := newDecodeState(nil, b)
decState := newDecodeState(b)
for u := uint64(0); ; u = (u + 1) * 7 {
b.Reset()
encState.encodeUint(u)
......@@ -75,9 +75,9 @@ func TestUintCodec(t *testing.T) {
func verifyInt(i int64, t *testing.T) {
defer testError(t)
var b = new(bytes.Buffer)
encState := newEncoderState(nil, b)
encState := newEncoderState(b)
encState.encodeInt(i)
decState := newDecodeState(nil, b)
decState := newDecodeState(b)
decState.buf = make([]byte, 8)
j := decState.decodeInt()
if i != j {
......@@ -111,9 +111,16 @@ var complexResult = []byte{0x07, 0xFE, 0x31, 0x40, 0xFE, 0x33, 0x40}
// The result of encoding "hello" with field number 7
var bytesResult = []byte{0x07, 0x05, 'h', 'e', 'l', 'l', 'o'}
func newencoderState(b *bytes.Buffer) *encoderState {
func newDecodeState(buf *bytes.Buffer) *decoderState {
d := new(decoderState)
d.b = buf
d.buf = make([]byte, uint64Size)
return d
}
func newEncoderState(b *bytes.Buffer) *encoderState {
b.Reset()
state := newEncoderState(nil, b)
state := &encoderState{enc: nil, b: b}
state.fieldnum = -1
return state
}
......@@ -127,7 +134,7 @@ func TestScalarEncInstructions(t *testing.T) {
{
data := struct{ a bool }{true}
instr := &encInstr{encBool, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(boolResult, b.Bytes()) {
t.Errorf("bool enc instructions: expected % x got % x", boolResult, b.Bytes())
......@@ -139,7 +146,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a int }{17}
instr := &encInstr{encInt, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int enc instructions: expected % x got % x", signedResult, b.Bytes())
......@@ -151,7 +158,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a uint }{17}
instr := &encInstr{encUint, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint enc instructions: expected % x got % x", unsignedResult, b.Bytes())
......@@ -163,7 +170,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a int8 }{17}
instr := &encInstr{encInt8, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int8 enc instructions: expected % x got % x", signedResult, b.Bytes())
......@@ -175,7 +182,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a uint8 }{17}
instr := &encInstr{encUint8, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint8 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
......@@ -187,7 +194,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a int16 }{17}
instr := &encInstr{encInt16, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int16 enc instructions: expected % x got % x", signedResult, b.Bytes())
......@@ -199,7 +206,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a uint16 }{17}
instr := &encInstr{encUint16, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint16 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
......@@ -211,7 +218,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a int32 }{17}
instr := &encInstr{encInt32, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int32 enc instructions: expected % x got % x", signedResult, b.Bytes())
......@@ -223,7 +230,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a uint32 }{17}
instr := &encInstr{encUint32, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint32 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
......@@ -235,7 +242,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a int64 }{17}
instr := &encInstr{encInt64, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(signedResult, b.Bytes()) {
t.Errorf("int64 enc instructions: expected % x got % x", signedResult, b.Bytes())
......@@ -247,7 +254,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a uint64 }{17}
instr := &encInstr{encUint64, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(unsignedResult, b.Bytes()) {
t.Errorf("uint64 enc instructions: expected % x got % x", unsignedResult, b.Bytes())
......@@ -259,7 +266,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a float32 }{17}
instr := &encInstr{encFloat32, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(floatResult, b.Bytes()) {
t.Errorf("float32 enc instructions: expected % x got % x", floatResult, b.Bytes())
......@@ -271,7 +278,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a float64 }{17}
instr := &encInstr{encFloat64, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(floatResult, b.Bytes()) {
t.Errorf("float64 enc instructions: expected % x got % x", floatResult, b.Bytes())
......@@ -283,7 +290,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a []byte }{[]byte("hello")}
instr := &encInstr{encUint8Array, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(bytesResult, b.Bytes()) {
t.Errorf("bytes enc instructions: expected % x got % x", bytesResult, b.Bytes())
......@@ -295,7 +302,7 @@ func TestScalarEncInstructions(t *testing.T) {
b.Reset()
data := struct{ a string }{"hello"}
instr := &encInstr{encString, 6, 0, 0}
state := newencoderState(b)
state := newEncoderState(b)
instr.op(instr, state, unsafe.Pointer(&data))
if !bytes.Equal(bytesResult, b.Bytes()) {
t.Errorf("string enc instructions: expected % x got % x", bytesResult, b.Bytes())
......@@ -303,7 +310,7 @@ func TestScalarEncInstructions(t *testing.T) {
}
}
func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) {
func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, p unsafe.Pointer) {
defer testError(t)
v := int(state.decodeUint())
if v+state.fieldnum != 6 {
......@@ -313,9 +320,9 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
state.fieldnum = 6
}
func newDecodeStateFromData(data []byte) *decodeState {
func newDecodeStateFromData(data []byte) *decoderState {
b := bytes.NewBuffer(data)
state := newDecodeState(nil, b)
state := newDecodeState(b)
state.fieldnum = -1
return state
}
......@@ -997,9 +1004,9 @@ func TestInvalidField(t *testing.T) {
var bad0 Bad0
bad0.CH = make(chan int)
b := new(bytes.Buffer)
var nilEncoder *Encoder
err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
if err == nil {
dummyEncoder := new(Encoder) // sufficient for this purpose.
dummyEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
if err := dummyEncoder.err; err == nil {
t.Error("expected error; got none")
} else if strings.Index(err.String(), "type") < 0 {
t.Error("expected type error; got", err)
......
......@@ -5,6 +5,7 @@
package gob
import (
"bufio"
"bytes"
"io"
"os"
......@@ -21,7 +22,7 @@ type Decoder struct {
wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects
countState *decodeState // reads counts from wire
freeList *decoderState // list of free decoderStates; avoids reallocation
countBuf []byte // used for decoding integers while parsing messages
tmp []byte // temporary storage for i/o; saves reallocating
err os.Error
......@@ -30,7 +31,7 @@ type Decoder struct {
// NewDecoder returns a new decoder that reads from the io.Reader.
func NewDecoder(r io.Reader) *Decoder {
dec := new(Decoder)
dec.r = r
dec.r = bufio.NewReader(r)
dec.wireType = make(map[typeId]*wireType)
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
dec.ignorerCache = make(map[typeId]**decEngine)
......@@ -49,7 +50,7 @@ func (dec *Decoder) recvType(id typeId) {
// Type:
wire := new(wireType)
dec.err = dec.decodeValue(tWireType, reflect.NewValue(wire))
dec.decodeValue(tWireType, reflect.NewValue(wire))
if dec.err != nil {
return
}
......@@ -184,7 +185,7 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error {
dec.err = nil
id := dec.decodeTypeSequence(false)
if dec.err == nil {
dec.err = dec.decodeValue(id, value)
dec.decodeValue(id, value)
}
return dec.err
}
......
......@@ -19,7 +19,9 @@ type Encoder struct {
w []io.Writer // where to send the data
sent map[reflect.Type]typeId // which types we've already sent
countState *encoderState // stage for writing counts
freeList *encoderState // list of free encoderStates; avoids reallocation
buf []byte // for collecting the output.
byteBuf bytes.Buffer // buffer for top-level encoderState
err os.Error
}
......@@ -28,7 +30,7 @@ func NewEncoder(w io.Writer) *Encoder {
enc := new(Encoder)
enc.w = []io.Writer{w}
enc.sent = make(map[reflect.Type]typeId)
enc.countState = newEncoderState(enc, new(bytes.Buffer))
enc.countState = enc.newEncoderState(new(bytes.Buffer))
return enc
}
......@@ -78,12 +80,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
}
}
// sendActualType sends the requested type, without further investigation, unless
// it's been sent before.
func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
if _, alreadySent := enc.sent[actual]; alreadySent {
return false
}
typeLock.Lock()
info, err := getTypeInfo(ut)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type, both what the user gave us and the base type.
enc.sent[ut.base] = info.id
if ut.user != ut.base {
enc.sent[ut.user] = info.id
}
// Now send the inner types
switch st := actual.(type) {
case *reflect.StructType:
for i := 0; i < st.NumField(); i++ {
enc.sendType(w, state, st.Field(i).Type)
}
case reflect.ArrayOrSliceType:
enc.sendType(w, state, st.Elem())
}
return true
}
// sendType sends the type info to the other side, if necessary.
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
// Drill down to the base type.
ut := userType(origt)
rt := ut.base
if ut.isGobEncoder {
// The rules are different: regardless of the underlying type's representation,
// we need to tell the other side that this exact type is a GobEncoder.
return enc.sendActualType(w, state, ut, ut.user)
}
switch rt := rt.(type) {
// It's a concrete value, so drill down to the base type.
switch rt := ut.base.(type) {
default:
// Basic types and interfaces do not need to be described.
return
......@@ -109,43 +156,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ
return
}
// Have we already sent this type? This time we ask about the base type.
if _, alreadySent := enc.sent[rt]; alreadySent {
return
}
// Need to send it.
typeLock.Lock()
info, err := getTypeInfo(rt)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
// Send the pair (-id, type)
// Id:
state.encodeInt(-int64(info.id))
// Type:
enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
enc.writeMessage(w, state.b)
if enc.err != nil {
return
}
// Remember we've sent this type.
enc.sent[rt] = info.id
// Remember we've sent the top-level, possibly indirect type too.
enc.sent[origt] = info.id
// Now send the inner types
switch st := rt.(type) {
case *reflect.StructType:
for i := 0; i < st.NumField(); i++ {
enc.sendType(w, state, st.Field(i).Type)
}
case reflect.ArrayOrSliceType:
enc.sendType(w, state, st.Elem())
}
return true
return enc.sendActualType(w, state, ut, ut.base)
}
// Encode transmits the data item represented by the empty interface value,
......@@ -159,11 +170,14 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// sent.
func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
// Make sure the type is known to the other side.
// First, have we already sent this (base) type?
base := ut.base
if _, alreadySent := enc.sent[base]; !alreadySent {
// First, have we already sent this type?
rt := ut.base
if ut.isGobEncoder {
rt = ut.user
}
if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it.
sent := enc.sendType(w, state, base)
sent := enc.sendType(w, state, rt)
if enc.err != nil {
return
}
......@@ -172,13 +186,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use
// need to send the type info but we do need to update enc.sent.
if !sent {
typeLock.Lock()
info, err := getTypeInfo(base)
info, err := getTypeInfo(ut)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
enc.sent[base] = info.id
enc.sent[rt] = info.id
}
}
}
......@@ -206,7 +220,8 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
}
enc.err = nil
state := newEncoderState(enc, new(bytes.Buffer))
enc.byteBuf.Reset()
state := enc.newEncoderState(&enc.byteBuf)
enc.sendTypeDescriptor(enc.writer(), state, ut)
enc.sendTypeId(state, ut)
......@@ -215,12 +230,11 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
}
// Encode the object.
err = enc.encode(state.b, value, ut)
if err != nil {
enc.setError(err)
} else {
enc.encode(state.b, value, ut)
if enc.err == nil {
enc.writeMessage(enc.writer(), state.b)
}
enc.freeEncoderState(state)
return enc.err
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file contains tests of the GobEncoder/GobDecoder support.
package gob
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
)
// Types that implement the GobEncoder/Decoder interfaces.
type ByteStruct struct {
a byte // not an exported field
}
type StringStruct struct {
s string // not an exported field
}
type Gobber int
type ValueGobber string // encodes with a value, decodes with a pointer.
// The relevant methods
func (g *ByteStruct) GobEncode() ([]byte, os.Error) {
b := make([]byte, 3)
b[0] = g.a
b[1] = g.a + 1
b[2] = g.a + 2
return b, nil
}
func (g *ByteStruct) GobDecode(data []byte) os.Error {
if g == nil {
return os.ErrorString("NIL RECEIVER")
}
// Expect N sequential-valued bytes.
if len(data) == 0 {
return os.EOF
}
g.a = data[0]
for i, c := range data {
if c != g.a+byte(i) {
return os.ErrorString("invalid data sequence")
}
}
return nil
}
func (g *StringStruct) GobEncode() ([]byte, os.Error) {
return []byte(g.s), nil
}
func (g *StringStruct) GobDecode(data []byte) os.Error {
// Expect N sequential-valued bytes.
if len(data) == 0 {
return os.EOF
}
a := data[0]
for i, c := range data {
if c != a+byte(i) {
return os.ErrorString("invalid data sequence")
}
}
g.s = string(data)
return nil
}
func (g *Gobber) GobEncode() ([]byte, os.Error) {
return []byte(fmt.Sprintf("VALUE=%d", *g)), nil
}
func (g *Gobber) GobDecode(data []byte) os.Error {
_, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g))
return err
}
func (v ValueGobber) GobEncode() ([]byte, os.Error) {
return []byte(fmt.Sprintf("VALUE=%s", v)), nil
}
func (v *ValueGobber) GobDecode(data []byte) os.Error {
_, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v))
return err
}
// Structs that include GobEncodable fields.
type GobTest0 struct {
X int // guarantee we have something in common with GobTest*
G *ByteStruct
}
type GobTest1 struct {
X int // guarantee we have something in common with GobTest*
G *StringStruct
}
type GobTest2 struct {
X int // guarantee we have something in common with GobTest*
G string // not a GobEncoder - should give us errors
}
type GobTest3 struct {
X int // guarantee we have something in common with GobTest*
G *Gobber
}
type GobTest4 struct {
X int // guarantee we have something in common with GobTest*
V ValueGobber
}
type GobTest5 struct {
X int // guarantee we have something in common with GobTest*
V *ValueGobber
}
type GobTestIgnoreEncoder struct {
X int // guarantee we have something in common with GobTest*
}
type GobTestValueEncDec struct {
X int // guarantee we have something in common with GobTest*
G StringStruct // not a pointer.
}
type GobTestIndirectEncDec struct {
X int // guarantee we have something in common with GobTest*
G ***StringStruct // indirections to the receiver.
}
func TestGobEncoderField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest0)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.a != 'A' {
t.Errorf("expected 'A' got %c", x.G.a)
}
// Now a field that's not a structure.
b.Reset()
gobber := Gobber(23)
err = enc.Encode(GobTest3{17, &gobber})
if err != nil {
t.Fatal("encode error:", err)
}
y := new(GobTest3)
err = dec.Decode(y)
if err != nil {
t.Fatal("decode error:", err)
}
if *y.G != 23 {
t.Errorf("expected '23 got %d", *y.G)
}
}
// Even though the field is a value, we can still take its address
// and should be able to call the methods.
func TestGobEncoderValueField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTestValueEncDec{17, StringStruct{"HIJKL"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestValueEncDec)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.s != "HIJKL" {
t.Errorf("expected `HIJKL` got %s", x.G.s)
}
}
// GobEncode/Decode should work even if the value is
// more indirect than the receiver.
func TestGobEncoderIndirectField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
s := &StringStruct{"HIJKL"}
sp := &s
err := enc.Encode(GobTestIndirectEncDec{17, &sp})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIndirectEncDec)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if (***x.G).s != "HIJKL" {
t.Errorf("expected `HIJKL` got %s", (***x.G).s)
}
}
// As long as the fields have the same name and implement the
// interface, we can cross-connect them. Not sure it's useful
// and may even be bad but it works and it's hard to prevent
// without exposing the contents of the object, which would
// defeat the purpose.
func TestGobEncoderFieldsOfDifferentType(t *testing.T) {
// first, string in field to byte in field
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest0)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.G.a != 'A' {
t.Errorf("expected 'A' got %c", x.G.a)
}
// now the other direction, byte in field to string in field
b.Reset()
err = enc.Encode(GobTest0{17, &ByteStruct{'X'}})
if err != nil {
t.Fatal("encode error:", err)
}
y := new(GobTest1)
err = dec.Decode(y)
if err != nil {
t.Fatal("decode error:", err)
}
if y.G.s != "XYZ" {
t.Fatalf("expected `XYZ` got %c", y.G.s)
}
}
// Test that we can encode a value and decode into a pointer.
func TestGobEncoderValueEncoder(t *testing.T) {
// Encode a value, decode into a pointer.
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest4{17, ValueGobber("hello")})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTest5)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if *x.V != "hello" {
t.Errorf("expected `hello` got %s", x.V)
}
}
func TestGobEncoderFieldTypeError(t *testing.T) {
// GobEncoder to non-decoder: error
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := &GobTest2{}
err = dec.Decode(x)
if err == nil {
t.Fatal("expected decode error for mismatched fields (encoder to non-decoder)")
}
if strings.Index(err.String(), "type") < 0 {
t.Fatal("expected type error; got", err)
}
// Non-encoder to GobDecoder: error
b.Reset()
err = enc.Encode(GobTest2{17, "ABC"})
if err != nil {
t.Fatal("encode error:", err)
}
y := &GobTest1{}
err = dec.Decode(y)
if err == nil {
t.Fatal("expected decode error for mistmatched fields (non-encoder to decoder)")
}
if strings.Index(err.String(), "type") < 0 {
t.Fatal("expected type error; got", err)
}
}
// Even though ByteStruct is a struct, it's treated as a singleton at the top level.
func TestGobEncoderStructSingleton(t *testing.T) {
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(&ByteStruct{'A'})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(ByteStruct)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.a != 'A' {
t.Errorf("expected 'A' got %c", x.a)
}
}
func TestGobEncoderNonStructSingleton(t *testing.T) {
b := new(bytes.Buffer)
enc := NewEncoder(b)
err := enc.Encode(Gobber(1234))
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
var x Gobber
err = dec.Decode(&x)
if err != nil {
t.Fatal("decode error:", err)
}
if x != 1234 {
t.Errorf("expected 1234 got %c", x)
}
}
func TestGobEncoderIgnoreStructField(t *testing.T) {
b := new(bytes.Buffer)
// First a field that's a structure.
enc := NewEncoder(b)
err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIgnoreEncoder)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.X != 17 {
t.Errorf("expected 17 got %c", x.X)
}
}
func TestGobEncoderIgnoreNonStructField(t *testing.T) {
b := new(bytes.Buffer)
// This time a field that's not a structure.
enc := NewEncoder(b)
gobber := Gobber(23)
err := enc.Encode(GobTest3{17, &gobber})
if err != nil {
t.Fatal("encode error:", err)
}
dec := NewDecoder(b)
x := new(GobTestIgnoreEncoder)
err = dec.Decode(x)
if err != nil {
t.Fatal("decode error:", err)
}
if x.X != 17 {
t.Errorf("expected 17 got %c", x.X)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gob
import (
"bytes"
"fmt"
"io"
"os"
"runtime"
"testing"
)
type Bench struct {
A int
B float64
C string
D []byte
}
func benchmarkEndToEnd(r io.Reader, w io.Writer, b *testing.B) {
b.StopTimer()
enc := NewEncoder(w)
dec := NewDecoder(r)
bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
b.StartTimer()
for i := 0; i < b.N; i++ {
if enc.Encode(bench) != nil {
panic("encode error")
}
if dec.Decode(bench) != nil {
panic("decode error")
}
}
}
func BenchmarkEndToEndPipe(b *testing.B) {
r, w, err := os.Pipe()
if err != nil {
panic("can't get pipe:" + err.String())
}
benchmarkEndToEnd(r, w, b)
}
func BenchmarkEndToEndByteBuffer(b *testing.B) {
var buf bytes.Buffer
benchmarkEndToEnd(&buf, &buf, b)
}
func TestCountEncodeMallocs(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
mallocs := 0 - runtime.MemStats.Mallocs
const count = 1000
for i := 0; i < count; i++ {
err := enc.Encode(bench)
if err != nil {
t.Fatal("encode:", err)
}
}
mallocs += runtime.MemStats.Mallocs
fmt.Printf("mallocs per encode of type Bench: %d\n", mallocs/count)
}
func TestCountDecodeMallocs(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
const count = 1000
for i := 0; i < count; i++ {
err := enc.Encode(bench)
if err != nil {
t.Fatal("encode:", err)
}
}
dec := NewDecoder(&buf)
mallocs := 0 - runtime.MemStats.Mallocs
for i := 0; i < count; i++ {
*bench = Bench{}
err := dec.Decode(&bench)
if err != nil {
t.Fatal("decode:", err)
}
}
mallocs += runtime.MemStats.Mallocs
fmt.Printf("mallocs per decode of type Bench: %d\n", mallocs/count)
}
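// Illustrative addition, not part of the original change set: the two tests
// above share an allocation-counting pattern (snapshot runtime.MemStats.Mallocs,
// run the body a fixed number of times, report the per-iteration delta). A
// hedged sketch of that pattern as a hypothetical helper, using only the
// era-appropriate runtime.MemStats global already read directly above.
func mallocsPerRun(count int, body func()) uint64 {
	mallocs := 0 - runtime.MemStats.Mallocs
	for i := 0; i < count; i++ {
		body()
	}
	mallocs += runtime.MemStats.Mallocs
	return mallocs / uint64(count)
}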
......@@ -26,7 +26,7 @@ var basicTypes = []typeT{
func getTypeUnlocked(name string, rt reflect.Type) gobType {
typeLock.Lock()
defer typeLock.Unlock()
t, err := getType(name, rt)
t, err := getBaseType(name, rt)
if err != nil {
panic("getTypeUnlocked: " + err.String())
}
......@@ -126,27 +126,27 @@ func TestMapType(t *testing.T) {
}
type Bar struct {
x string
X string
}
// This structure has pointers and refers to itself, making it a good test case.
type Foo struct {
a int
b int32 // will become int
c string
d []byte
e *float64 // will become float64
f ****float64 // will become float64
g *Bar
h *Bar // should not interpolate the definition of Bar again
i *Foo // will not explode
A int
B int32 // will become int
C string
D []byte
E *float64 // will become float64
F ****float64 // will become float64
G *Bar
H *Bar // should not interpolate the definition of Bar again
I *Foo // will not explode
}
func TestStructType(t *testing.T) {
sstruct := getTypeUnlocked("Foo", reflect.Typeof(Foo{}))
str := sstruct.string()
// If we can print it correctly, we built it correctly.
expected := "Foo = struct { a int; b int; c string; d bytes; e float; f float; g Bar = struct { x string; }; h Bar; i Foo; }"
expected := "Foo = struct { A int; B int; C string; D bytes; E float; F float; G Bar = struct { X string; }; H Bar; I Foo; }"
if str != expected {
t.Errorf("struct printed as %q; expected %q", str, expected)
}
......