Commit 02bb39f4 by Edward Thomson

stream registration: take an enum type

Accept an enum (`git_stream_t`) during custom stream registration that
indicates whether the registration structure should be used for standard
(non-TLS) streams or TLS streams.
parent 52478d7d
......@@ -72,19 +72,31 @@ typedef struct {
} git_stream_registration;
/**
* The type of stream to register.
*/
typedef enum {
/** A standard (non-TLS) socket. */
GIT_STREAM_STANDARD = 1,
/** A TLS-encrypted socket. */
GIT_STREAM_TLS = 2,
} git_stream_t;
/**
* Register stream constructors for the library to use
*
* If a registration structure is already set, it will be overwritten.
* Pass `NULL` in order to deregister the current constructor and return
* to the system defaults.
*
* The type parameter may be a bitwise AND of types.
*
* @param type the type or types of stream to register
* @param registration the registration data
* @param tls 1 if the registration is for TLS streams, 0 for regular
* (insecure) sockets
* @return 0 or an error code
*/
GIT_EXTERN(int) git_stream_register(
int tls, git_stream_registration *registration);
git_stream_t type, git_stream_registration *registration);
/** @name Deprecated TLS Stream Registration Functions
*
......
......@@ -36,22 +36,42 @@ int git_stream_registry_global_init(void)
return 0;
}
int git_stream_registry_lookup(git_stream_registration *out, int tls)
GIT_INLINE(void) stream_registration_cpy(
git_stream_registration *target,
git_stream_registration *src)
{
git_stream_registration *target = tls ?
&stream_registry.callbacks :
&stream_registry.tls_callbacks;
if (src)
memcpy(target, src, sizeof(git_stream_registration));
else
memset(target, 0, sizeof(git_stream_registration));
}
int git_stream_registry_lookup(git_stream_registration *out, git_stream_t type)
{
git_stream_registration *target;
int error = GIT_ENOTFOUND;
assert(out);
switch(type) {
case GIT_STREAM_STANDARD:
target = &stream_registry.callbacks;
break;
case GIT_STREAM_TLS:
target = &stream_registry.tls_callbacks;
break;
default:
assert(0);
return -1;
}
if (git_rwlock_rdlock(&stream_registry.lock) < 0) {
giterr_set(GITERR_OS, "failed to lock stream registry");
return -1;
}
if (target->init) {
memcpy(out, target, sizeof(git_stream_registration));
stream_registration_cpy(out, target);
error = 0;
}
......@@ -59,12 +79,8 @@ int git_stream_registry_lookup(git_stream_registration *out, int tls)
return error;
}
int git_stream_register(int tls, git_stream_registration *registration)
int git_stream_register(git_stream_t type, git_stream_registration *registration)
{
git_stream_registration *target = tls ?
&stream_registry.callbacks :
&stream_registry.tls_callbacks;
assert(!registration || registration->init);
GITERR_CHECK_VERSION(registration, GIT_STREAM_VERSION, "stream_registration");
......@@ -74,10 +90,11 @@ int git_stream_register(int tls, git_stream_registration *registration)
return -1;
}
if (registration)
memcpy(target, registration, sizeof(git_stream_registration));
else
memset(target, 0, sizeof(git_stream_registration));
if ((type & GIT_STREAM_STANDARD) == GIT_STREAM_STANDARD)
stream_registration_cpy(&stream_registry.callbacks, registration);
if ((type & GIT_STREAM_TLS) == GIT_STREAM_TLS)
stream_registration_cpy(&stream_registry.tls_callbacks, registration);
git_rwlock_wrunlock(&stream_registry.lock);
return 0;
......@@ -92,8 +109,8 @@ int git_stream_register_tls(git_stream_cb ctor)
registration.init = ctor;
registration.wrap = NULL;
return git_stream_register(1, &registration);
return git_stream_register(GIT_STREAM_TLS, &registration);
} else {
return git_stream_register(1, NULL);
return git_stream_register(GIT_STREAM_TLS, NULL);
}
}
......@@ -14,6 +14,6 @@
int git_stream_registry_global_init(void);
/** Lookup a stream registration. */
extern int git_stream_registry_lookup(git_stream_registration *out, int tls);
extern int git_stream_registry_lookup(git_stream_registration *out, git_stream_t type);
#endif
......@@ -224,7 +224,7 @@ int git_socket_stream_new(
assert(out && host && port);
if ((error = git_stream_registry_lookup(&custom, 0)) == 0)
if ((error = git_stream_registry_lookup(&custom, GIT_STREAM_STANDARD)) == 0)
init = custom.init;
else if (error == GIT_ENOTFOUND)
init = default_socket_stream_new;
......
......@@ -23,7 +23,7 @@ int git_tls_stream_new(git_stream **out, const char *host, const char *port)
assert(out && host && port);
if ((error = git_stream_registry_lookup(&custom, 1)) == 0) {
if ((error = git_stream_registry_lookup(&custom, GIT_STREAM_TLS)) == 0) {
init = custom.init;
} else if (error == GIT_ENOTFOUND) {
#ifdef GIT_SECURE_TRANSPORT
......@@ -52,7 +52,7 @@ int git_tls_stream_wrap(git_stream **out, git_stream *in, const char *host)
assert(out && in);
if (git_stream_registry_lookup(&custom, 1) == 0) {
if (git_stream_registry_lookup(&custom, GIT_STREAM_TLS) == 0) {
wrap = custom.wrap;
} else {
#ifdef GIT_SECURE_TRANSPORT
......
......@@ -7,6 +7,11 @@
static git_stream test_stream;
static int ctor_called;
void test_core_stream__cleanup(void)
{
cl_git_pass(git_stream_register(GIT_STREAM_TLS | GIT_STREAM_STANDARD, NULL));
}
static int test_stream_init(git_stream **out, const char *host, const char *port)
{
GIT_UNUSED(host);
......@@ -39,14 +44,14 @@ void test_core_stream__register_insecure(void)
registration.wrap = test_stream_wrap;
ctor_called = 0;
cl_git_pass(git_stream_register(0, &registration));
cl_git_pass(git_stream_register(GIT_STREAM_STANDARD, &registration));
cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream);
ctor_called = 0;
stream = NULL;
cl_git_pass(git_stream_register(0, NULL));
cl_git_pass(git_stream_register(GIT_STREAM_STANDARD, NULL));
cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
cl_assert_equal_i(0, ctor_called);
......@@ -66,14 +71,14 @@ void test_core_stream__register_tls(void)
registration.wrap = test_stream_wrap;
ctor_called = 0;
cl_git_pass(git_stream_register(1, &registration));
cl_git_pass(git_stream_register(GIT_STREAM_TLS, &registration));
cl_git_pass(git_tls_stream_new(&stream, "localhost", "443"));
cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream);
ctor_called = 0;
stream = NULL;
cl_git_pass(git_stream_register(1, NULL));
cl_git_pass(git_stream_register(GIT_STREAM_TLS, NULL));
error = git_tls_stream_new(&stream, "localhost", "443");
/* We don't have TLS support enabled, or we're on Windows,
......@@ -91,6 +96,28 @@ void test_core_stream__register_tls(void)
git_stream_free(stream);
}
void test_core_stream__register_both(void)
{
git_stream *stream;
git_stream_registration registration = {0};
registration.version = 1;
registration.init = test_stream_init;
registration.wrap = test_stream_wrap;
cl_git_pass(git_stream_register(GIT_STREAM_STANDARD | GIT_STREAM_TLS, &registration));
ctor_called = 0;
cl_git_pass(git_tls_stream_new(&stream, "localhost", "443"));
cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream);
ctor_called = 0;
cl_git_pass(git_socket_stream_new(&stream, "localhost", "80"));
cl_assert_equal_i(1, ctor_called);
cl_assert_equal_p(&test_stream, stream);
}
void test_core_stream__register_tls_deprecated(void)
{
git_stream *stream;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment