Tuesday, March 30, 2010

Private Accessor T4 Generation for NUnit

As mentioned in a previous post, I have been experimenting for quite some time with generating private accessors for NUnit. If you use MsTest, you can use the Visual Studio private accessor generator, which works a treat; sadly, it does not support NUnit. For me this is a nice drawcard for MsTest over NUnit. The accessors MsTest generates are cleaner than the ones my tool produces here, but hey, I don't have Microsoft's development budget :-)

There are a few hacks in it, and I might refine it over time as the need arises. The biggest limitation is that it does not recursively generate accessors for the private and internal classes found inside a targeted class. In my projects I get around this easily, as I am an avid user of interfaces and abstract classes; in most cases these can be substituted using my internal type mapping dictionary. For example, if you want to generate an accessor for an internal class, that class must have an interface to use with this accessor generator. This is not as bad as it sounds if you are already interfacing your classes for extensive unit testing.

Download the sample project to see an example of a T4 template consuming the private accessor generator. But here's the generator code for a quick glance.


namespace ReesTestToolkit
{
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Reflection;
    using System.Text;

    /// <summary>
    /// A code generator designed to be used with T4 code generation. This generator creates private accessor wrappers for a given type.
    /// </summary>
    /// <remarks>
    /// The generated class is named <c>{TargetTypeName}_Accessor</c>. It reaches the target's non-public fields, properties,
    /// methods and constructors through reflection. Generated source is handed to the code writer delegate in fragments;
    /// problems are reported through the error and warning writer delegates.
    /// </remarks>
    public class PrivateAccessorGenerator
    {
        private readonly Action<string> errorWriter;
        private readonly Type targetType;
        private readonly Action<string> warningWriter;
        private readonly Action<string> writer;
        private bool isStaticClass;

        /// <summary>
        /// Initializes a new instance of the <see cref="PrivateAccessorGenerator"/> class.
        /// </summary>
        /// <param name="targetType">Type of the target to generate a private accessor wrapper for.</param>
        /// <param name="writer">The code writer; receives the generated source.</param>
        /// <param name="errorWriter">The error writer; receives error messages.</param>
        /// <param name="warningWriter">The warning writer; receives warning messages.</param>
        public PrivateAccessorGenerator(Type targetType, Action<string> writer, Action<string> errorWriter, Action<string> warningWriter)
        {
            this.errorWriter = errorWriter;
            this.warningWriter = warningWriter;
            this.writer = writer;
            this.targetType = targetType;

            // Excluded by default: members every type inherits from System.Object.
            MemberExcludeList = new List<string>() { "MemberwiseClone", "Finalize", "Equals", "GetHashCode", "GetType" };
            InternalTypeMapping = new Dictionary<Type, Type>();
        }

        /// <summary>
        /// Gets the internal type mapping: the types to use in place of inaccessible (non-public) types.
        /// Each key is a non-public type. Each value is an accessible type the key is assignable to, for example an
        /// interface the key implements; it is used in the generated code wherever the key type appears.
        /// If no mapping is specified for a non-public type, <see cref="System.Object"/> is used instead.
        /// </summary>
        /// <value>The internal type mapping.</value>
        public IDictionary<Type, Type> InternalTypeMapping { get; private set; }

        /// <summary>
        /// Gets the member exclude list. These members will be excluded in the resulting private accessor.
        /// </summary>
        /// <value>The member exclude list.</value>
        public IList<string> MemberExcludeList { get; private set; }

        /// <summary>
        /// Gets the number of properties generated. Used for testing of the generator.
        /// </summary>
        /// <value>The property count.</value>
        public int PropertyCount { get; private set; }

        /// <summary>
        /// Gets the number of Methods generated. Used for testing of the generator.
        /// </summary>
        /// <value>The Method count.</value>
        public int MethodCount { get; private set; }

        /// <summary>
        /// Gets the number of Constants generated. Used for testing of the generator.
        /// </summary>
        /// <value>The Constants count.</value>
        public int ConstantCount { get; private set; }

        /// <summary>
        /// Gets the number of Fields generated. Used for testing of the generator.
        /// </summary>
        /// <value>The Fields count.</value>
        public int FieldCount { get; private set; }

        /// <summary>
        /// Gets the number of target constructors for which accessor constructors were generated.
        /// </summary>
        /// <value>The constructor count.</value>
        public int ConstructorCount { get; private set; }

        /// <summary>
        /// Generates the private accessor class for the target type and passes its source to the code writer.
        /// </summary>
        /// <remarks>
        /// A missing target type or code writer is reported through the error writer, and nothing is generated.
        /// Once generation has started, any exception is reported through the error writer rather than thrown.
        /// The class footer is always written (from a <c>finally</c> block).
        /// </remarks>
        /// <exception cref="InvalidOperationException">No error writer was supplied, so problems cannot be reported.</exception>
        public void Generate()
        {
            if (this.errorWriter == null)
            {
                throw new InvalidOperationException("errorWriter is null");
            }

            // Without a target type or a writer not even the footer can be produced, so stop here.
            if (this.targetType == null)
            {
                this.errorWriter("targetType cannot be null");
                return;
            }

            if (this.writer == null)
            {
                this.errorWriter("code writer delegate cannot be null");
                return;
            }

            try
            {
                // A C# static class compiles to an abstract sealed type without public constructors.
                this.isStaticClass = this.targetType.GetConstructors().Length == 0 && this.targetType.IsAbstract && this.targetType.IsSealed;

                this.Preconditions();

                this.WriteHeader();

                this.writer("        // Fields...");
                this.WriteStaticFields();

                if (!this.isStaticClass)
                {
                    this.WriteFields();

                    this.writer("        // Constructors...");
                    this.WriteConstructors();
                }

                this.writer("        // Properties...");
                this.WriteProperties(true);
                if (!this.isStaticClass)
                {
                    this.WriteProperties(false);
                }

                this.writer("        // Methods...");
                this.WriteMethods(true);
                if (!this.isStaticClass)
                {
                    this.WriteMethods(false);
                }
            } catch (Exception ex)
            {
                // Report the failure through the error writer instead of propagating it.
                this.errorWriter(ex.ToString());
            } finally
            {
                this.WriteFooter();
            }
        }

        /// <summary>
        /// Corrects the generic format. This assumes the correct type has already been selected using <see cref="PublicizableType"/>.
        /// </summary>
        /// <param name="fullTypeName">Full name of the type, e.g. <c>System.Collections.Generic.IList`1[...]</c>.</param>
        /// <param name="genericArgs">The generic args.</param>
        /// <returns>
        /// A C# formatted generic type, e.g. <c>System.Collections.Generic.IList&lt;System.String&gt;</c> instead of
        /// <c>IList`1</c>. <paramref name="fullTypeName"/> is returned unchanged when there are no generic arguments or it
        /// carries no backtick arity marker.
        /// </returns>
        private static string CorrectGenericFormat(string fullTypeName, IEnumerable<Type> genericArgs)
        {
            var args = genericArgs.ToList();

            // No backtick means a non-generic name, e.g. a generic internal type that was mapped to System.Object.
            int arityMarker = string.IsNullOrEmpty(fullTypeName) ? -1 : fullTypeName.IndexOf('`');
            if (args.Count == 0 || arityMarker < 0)
            {
                return fullTypeName;
            }

            var typeName = new StringBuilder(fullTypeName.Substring(0, arityMarker).Replace('+', '.'));
            typeName.Append('<');
            for (int index = 0; index < args.Count; index++)
            {
                if (index > 0)
                {
                    typeName.Append(", ");
                }

                var argument = args[index];
                typeName.Append(argument.IsGenericType
                    ? CorrectGenericFormat(SourceTypeName(argument), argument.GetGenericArguments())
                    : SourceTypeName(argument));
            }

            return typeName.Append('>').ToString();
        }

        /// <summary>
        /// Gets the name used to spell <paramref name="type"/> in generated C# source.
        /// </summary>
        /// <param name="type">The type to name.</param>
        /// <returns>The namespace-qualified name, or the bare name for generic parameters and types built from them.</returns>
        private static string SourceTypeName(Type type)
        {
            if (type.FullName == null && !type.IsGenericType)
            {
                // Generic parameters (T) and types built from them (T[]) have no FullName; C# refers to them by Name.
                return type.Name;
            }

            string name = type.FullName
                ?? (string.IsNullOrEmpty(type.Namespace) ? type.Name : type.Namespace + "." + type.Name);

            // Reflection separates nested types with '+', C# with '.'.
            return name.Replace('+', '.');
        }

        /// <summary>
        /// Builds the C# <c>where</c> clauses that reproduce the constraints declared on a generic method's type parameters.
        /// </summary>
        /// <param name="genericArguments">The generic type parameters of the method.</param>
        /// <returns>The constraint clauses, each followed by a space; an empty string when nothing is constrained.</returns>
        private static string BuildGenericConstraints(IEnumerable<Type> genericArguments)
        {
            var clauses = new List<string>();
            foreach (var parameter in genericArguments)
            {
                var attributes = parameter.GenericParameterAttributes;
                bool isStruct = (attributes & GenericParameterAttributes.NotNullableValueTypeConstraint) != 0;
                var constraints = new List<string>();

                // C# requires the primary constraint (class / struct) to come first.
                if (isStruct)
                {
                    constraints.Add("struct");
                } else if ((attributes & GenericParameterAttributes.ReferenceTypeConstraint) != 0)
                {
                    constraints.Add("class");
                }

                foreach (var constraintType in parameter.GetGenericParameterConstraints())
                {
                    // A struct constraint is recorded as System.ValueType, which C# does not accept as a constraint.
                    if (constraintType == typeof(ValueType))
                    {
                        continue;
                    }

                    constraints.Add(constraintType.IsGenericType
                        ? CorrectGenericFormat(SourceTypeName(constraintType), constraintType.GetGenericArguments())
                        : SourceTypeName(constraintType));
                }

                // "struct" already implies a default constructor and C# rejects "struct, new()".
                if (!isStruct && (attributes & GenericParameterAttributes.DefaultConstructorConstraint) != 0)
                {
                    constraints.Add("new()");
                }

                if (constraints.Count > 0)
                {
                    clauses.Add(string.Format("where {0} : {1} ", parameter.Name, string.Join(", ", constraints.ToArray())));
                }
            }

            return string.Join(" ", clauses.ToArray());
        }

        /// <summary>
        /// Writes one reflection-backed wrapper property. Nothing is written when the property can be neither read nor written.
        /// </summary>
        /// <param name="codeWriter">Receives the generated source fragments.</param>
        /// <param name="type">The C# type name of the wrapper property.</param>
        /// <param name="property">Name of the wrapped property; also the wrapper's name.</param>
        /// <param name="staticOnly">True to generate a static wrapper over a static property.</param>
        /// <param name="canWrite">True to generate a setter.</param>
        /// <param name="canRead">True to generate a getter.</param>
        private static void WriteOneProperty(Action<string> codeWriter, string type, string property, bool staticOnly, bool canWrite, bool canRead)
        {
            if (!canRead && !canWrite)
            {
                return;
            }

            string staticString = staticOnly ? "static " : string.Empty;
            string target = staticOnly ? "null" : "this.target";
            string bindingScope = staticOnly ? "BindingFlags.Static" : "BindingFlags.Instance";

            codeWriter(
                string.Format(
                    @"
        public {0}{1} {2} {{",
                    staticString,
                    type,
                    property));

            if (canRead)
            {
                codeWriter(
                    string.Format(
                        @"
            get {{
                try {{
                    return ({0})TargetType.GetProperty(""{1}"", BindingFlags.NonPublic | {3} | BindingFlags.Public).GetValue({2}, new object[] {{ }});
                }} catch (NullReferenceException ex) {{
                    throw new Exception(""The private accessor may be out of date, try regenerating its code. An Object Reference Not Set was thrown."", ex);
                }} catch (TargetInvocationException ex) {{
                    throw ex.InnerException ?? ex;
                }}
            }}
",
                        type,
                        property,
                        target,
                        bindingScope));
            }

            if (canWrite)
            {
                codeWriter(
                    string.Format(
                        @"
            set {{
                try {{
                    TargetType.GetProperty(""{0}"", BindingFlags.NonPublic | {2} | BindingFlags.Public).SetValue({1}, value, new object[] {{ }});
                }} catch (NullReferenceException ex) {{
                    throw new Exception(""The private accessor may be out of date, try regenerating its code. An Object Reference Not Set was thrown."", ex);
                }} catch (TargetInvocationException ex) {{
                    throw ex.InnerException ?? ex;
                }}
            }}",
                        property,
                        target,
                        bindingScope));
            }

            codeWriter("        }");
        }

        /// <summary>
        /// Writes one reflection-backed wrapper method to <paramref name="codeWriter"/>.
        /// </summary>
        /// <param name="codeWriter">Receives the generated source fragments.</param>
        /// <param name="returnType">The C# return type, or null to fall back to the generic parameter name.</param>
        /// <param name="methodName">Name of the wrapped method; also the wrapper's name.</param>
        /// <param name="methodGeneric">The generic parameter list including angle brackets (e.g. <c>&lt;T,U&gt;</c>), or an empty string.</param>
        /// <param name="argTypes">C# type names of the parameters, parallel to <paramref name="argNames"/>.</param>
        /// <param name="argNames">Parameter names, parallel to <paramref name="argTypes"/>.</param>
        /// <param name="isFunction">True when the wrapper returns the invoked method's result.</param>
        /// <param name="isStatic">True to generate a static wrapper that invokes a static method.</param>
        /// <param name="genericWhereSpec">Constraint clauses written after the parameter list, or an empty string.</param>
        private static void WriteOneMethod(
            Action<string> codeWriter,
            string returnType,
            string methodName,
            string methodGeneric,
            string[] argTypes,
            string[] argNames,
            bool isFunction,
            bool isStatic,
            string genericWhereSpec)
        {
            var parameterList = new StringBuilder();
            var argumentList = new StringBuilder();
            for (int index = 0; index < argTypes.Length; index++)
            {
                if (index > 0)
                {
                    parameterList.Append(", ");
                    argumentList.Append(", ");
                }

                parameterList.AppendFormat("{0} {1}", argTypes[index], argNames[index]);
                argumentList.Append(argNames[index]);
            }

            string argumentNames = argumentList.Length > 0
                ? ", new object[] {" + argumentList.ToString() + "}"
                : ", new object[] { }";

            // The accessor class itself inherits object.MemberwiseClone, so a wrapper of that name needs "new".
            string methodOverload = methodName == "MemberwiseClone" ? "new " : string.Empty;

            string genericTypeParamName = methodGeneric.Replace("<", string.Empty).Replace(">", string.Empty);
            if (isFunction && returnType == null)
            {
                returnType = genericTypeParamName;
            }

            // GetMethod yields the open generic definition, which cannot be invoked; close it over the wrapper's own
            // type parameters first.
            string closeGeneric = string.Empty;
            if (genericTypeParamName.Length > 0)
            {
                closeGeneric = ".MakeGenericMethod("
                    + string.Join(", ", genericTypeParamName.Split(',').Select(name => "typeof(" + name.Trim() + ")").ToArray())
                    + ")";
            }

            // Static methods are only found when BindingFlags.Static is passed; with Instance alone GetMethod returns null.
            string bindingScope = isStatic ? "BindingFlags.Static" : "BindingFlags.Instance";

            codeWriter(
                string.Format(
@"
        public {0}{1}{2} {3}{4}({5}) {6}{{",
                isStatic ? "static " : string.Empty,
                methodOverload,
                returnType ?? genericTypeParamName,
                methodName,
                methodGeneric,
                parameterList,
                genericWhereSpec));
            codeWriter(
                string.Format(
@"            try {{
                {0} TargetType.GetMethod(""{1}"", BindingFlags.NonPublic | {4} | BindingFlags.Public){5}.Invoke({2}{3});
            }} catch (NullReferenceException ex) {{
                throw new Exception(""The private accessor may be out of date, try regenerating its code. An Object Reference Not Set was thrown."", ex);
            }} catch (TargetInvocationException ex) {{
                throw ex.InnerException ?? ex;
            }}",
                    isFunction ? "return (" + returnType + ")" : string.Empty,
                    methodName,
                    isStatic ? "null" : "this.target",
                    argumentNames,
                    bindingScope,
                    closeGeneric));
            codeWriter("        }");
        }

        /// <summary>
        /// Gets the C# type name to use for a field, after substituting any non-public type.
        /// </summary>
        private string GetFieldTypeName(FieldInfo field)
        {
            if (!field.FieldType.IsGenericType)
            {
                return PublicizableType(field.FieldType);
            }

            return new GenericTypeStringMaker(PublicizableTypeAsType(field.FieldType)).Digest();
        }

        /// <summary>
        /// Gets the name of the accessible type to use in place of <paramref name="type"/>.
        /// </summary>
        /// <param name="type">The type found on the target.</param>
        /// <returns>The substitute's FullName, or its Name when it has no FullName (e.g. a generic parameter T).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        private string PublicizableType(Type type)
        {
            if (type == null)
            {
                errorWriter("Null type passed to PublicizableType.");
                throw new ArgumentNullException("type", "Null Type passed to type");
            }

            var publicType = PublicizableTypeAsType(type);

            // Generic parameters have no FullName; spelling them by Name keeps generated signatures valid.
            return publicType.FullName ?? publicType.Name;
        }

        /// <summary>
        /// Gets the accessible type to use in place of <paramref name="type"/>: the mapped type when a mapping exists,
        /// the type itself when it is public, and otherwise <see cref="System.Object"/> (a warning is written).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        private Type PublicizableTypeAsType(Type type)
        {
            if (type == null)
            {
                errorWriter("Null type passed to PublicizableType.");
                throw new ArgumentNullException("type", "Null Type passed to type");
            }

            Type mapped;
            if (InternalTypeMapping.TryGetValue(type, out mapped))
            {
                return mapped;
            }

            if (type.IsPublic)
            {
                return type;
            }

            this.warningWriter(String.Format("The internal type {0} has been found and is not mapped, System.Object will be used.", type.Name));

            return typeof(object);
        }

        /// <summary>
        /// Validates the target type and the internal type mapping. Problems are reported through the error and warning
        /// writers. Null-valued mappings and mappings of public types are removed.
        /// </summary>
        private void Preconditions()
        {
            if (!this.targetType.IsPublic && !InternalTypeMapping.ContainsKey(this.targetType))
            {
                errorWriter(
                    string.Format(
                        "Targetted type {0} is internal but no mapping has been supplied to use in place of the internal type.", this.targetType.Name));
            }

            // Iterate over a snapshot because entries may be removed inside the loop.
            foreach (var pair in InternalTypeMapping.ToArray())
            {
                if (pair.Value == null)
                {
                    errorWriter(string.Format("Internal Type mapping error: Internal type {0} is mapped to a null type.", pair.Key.Name));
                    InternalTypeMapping.Remove(pair.Key);
                    continue;
                }

                if (pair.Key.IsPublic)
                {
                    warningWriter(
                        string.Format(
                            "Internal Type mapping warning: Given type {0} is public and usable as is, it does not need to be mapped to another type.",
                            pair.Key.Name));
                    InternalTypeMapping.Remove(pair.Key);
                    continue;
                }

                if (!pair.Value.IsAssignableFrom(pair.Key))
                {
                    errorWriter(string.Format("Internal Type mapping error: Internal type {0} is not assignable into mapped type {1}", pair.Key.Name, pair.Value.Name));
                }
            }
        }

        /// <summary>
        /// Writes static wrapper properties for the target's non-public static fields. Constants get only a getter.
        /// Auto-property backing fields and "CS$" compiler-generated fields are skipped.
        /// </summary>
        private void WriteStaticFields()
        {
            foreach (var field in this.targetType.GetFields(BindingFlags.Static | BindingFlags.NonPublic)
                .Where(field => !field.Name.EndsWith("BackingField") && !field.Name.Contains("CS$")))
            {
                if (field.IsLiteral)
                {
                    ConstantCount++;
                } else
                {
                    FieldCount++;
                }

                var returnTypeType = PublicizableTypeAsType(field.FieldType);
                string returnType = returnTypeType.IsGenericType
                                        ? CorrectGenericFormat(returnTypeType.FullName, returnTypeType.GetGenericArguments())
                                        : returnTypeType.FullName;

                this.writer(
                    string.Format(
                        @"
        public static {0} {1} {{
            get {{
                try {{
                    return ({0})TargetType.GetField(""{1}"", BindingFlags.Static | BindingFlags.NonPublic).GetValue(null);
                }} catch (TargetInvocationException ex) {{
                    throw ex.InnerException ?? ex;
                }}
            }}",
                        returnType,
                        field.Name));
                if (!field.IsLiteral)
                {
                    this.writer(
                        string.Format(
                            @"
            set {{
                try {{
                    TargetType.GetField(""{0}"", BindingFlags.Static | BindingFlags.NonPublic).SetValue(null, value);
                }} catch (TargetInvocationException ex) {{
                    throw ex.InnerException ?? ex;
                }}
            }}",
                            field.Name));
                }

                this.writer(@"
        }");
            }
        }

        /// <summary>
        /// Writes the accessor constructors. The first wraps an existing target instance; the others mirror each instance
        /// constructor of the target and create the target through reflection.
        /// </summary>
        private void WriteConstructors()
        {
            this.writer(
                string.Format(
                    @"
        public {0}_Accessor({1} target) {{
            TargetType = target.GetType();
            this.target = target;
        }}",
                    this.targetType.Name,
                    PublicizableType(this.targetType)));

            foreach (var ctor in this.targetType.GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public))
            {
                ConstructorCount++;
                var parameters = ctor.GetParameters();
                if (parameters.Length == 0)
                {
                    this.writer(
                        string.Format(
                            @"
        public {0}_Accessor() {{
            this.target = TargetType.GetConstructor(BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public, null, new Type[] {{ }}, null).Invoke(new object[] {{ }});
        }}",
                            this.targetType.Name));
                    continue;
                }

                var typedArgs = new StringBuilder();
                var typesOnly = new StringBuilder();
                var namesOnly = new StringBuilder();
                for (int index = 0; index < parameters.Length; index++)
                {
                    var param = parameters[index];
                    if (index > 0)
                    {
                        typedArgs.Append(", ");
                        typesOnly.Append(", ");
                        namesOnly.Append(", ");
                    }

                    Type publicType = PublicizableTypeAsType(param.ParameterType);
                    string publicTypeName = new GenericTypeStringMaker(publicType).Digest();

                    typedArgs.Append(publicTypeName);
                    typedArgs.Append(" ");
                    typedArgs.Append(param.Name);

                    // GetConstructor needs the real parameter types. When a non-public type was substituted,
                    // typeof(substitute) would not match, so resolve the real type by its assembly qualified name instead.
                    if (publicType == param.ParameterType)
                    {
                        typesOnly.Append("typeof(");
                        typesOnly.Append(publicTypeName);
                        typesOnly.Append(")");
                    } else
                    {
                        typesOnly.Append("Type.GetType(\"");
                        typesOnly.Append(param.ParameterType.AssemblyQualifiedName);
                        typesOnly.Append("\")");
                    }

                    namesOnly.Append(param.Name);
                }

                this.writer(
                    string.Format(
                        @"
        public {1}_Accessor({0}) {{
            this.target = TargetType.GetConstructor(BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public, null, new[] {{ {2} }}, null).Invoke(new object[] {{ {3} }});
        }}",
                        typedArgs,
                        this.targetType.Name,
                        typesOnly,
                        namesOnly));
            }
        }

        /// <summary>
        /// Renders wrapper methods for the target's static or instance methods, then emits them via <see cref="PostProcessMethods"/>.
        /// Property/event accessors, "&lt;"-prefixed compiler-generated methods and excluded names are skipped.
        /// </summary>
        /// <param name="staticOnly">True for static methods, false for instance methods.</param>
        private void WriteMethods(bool staticOnly)
        {
            BindingFlags filter = (staticOnly ? BindingFlags.Static : BindingFlags.Instance) | BindingFlags.NonPublic | BindingFlags.Public;

            // GetMethods never returns constructors, so no constructor filter is needed; a "ctor" substring test
            // would wrongly drop ordinary methods such as "GetVector".
            var candidates = this.targetType.GetMethods(filter)
                .Where(m => !m.Name.StartsWith("get_", StringComparison.Ordinal)
                    && !m.Name.StartsWith("<", StringComparison.Ordinal)
                    && !m.Name.StartsWith("set_", StringComparison.Ordinal)
                    && !m.Name.StartsWith("add_", StringComparison.Ordinal)
                    && !m.Name.StartsWith("remove_", StringComparison.Ordinal)
                    && !MemberExcludeList.Contains(m.Name));

            var methods = new List<MethodCache>();
            foreach (var method in candidates)
            {
                Type returnType = PublicizableTypeAsType(method.ReturnType);
                string methodReturnType;
                if (method.ReturnType.IsGenericType)
                {
                    methodReturnType = new GenericTypeStringMaker(returnType).Digest();
                } else if (returnType == typeof(void))
                {
                    methodReturnType = "void";
                } else
                {
                    // A generic parameter return type (T, T[]) has no FullName; its Name is the C# spelling.
                    methodReturnType = returnType.FullName ?? returnType.Name;
                }

                string methodGeneric = string.Empty;
                string genericWhereSpec = string.Empty;
                if (method.IsGenericMethod)
                {
                    var genericArguments = method.GetGenericArguments();
                    methodGeneric = "<" + String.Join(",", genericArguments.Select(g => g.Name).ToArray()) + ">";
                    genericWhereSpec = BuildGenericConstraints(genericArguments);
                }

                var methodParameters = method.GetParameters()
                    .Select(p => CorrectGenericFormat(PublicizableType(p.ParameterType), p.ParameterType.GetGenericArguments()))
                    .ToArray();

                // Render into a MethodCache; PostProcessMethods decides whether its code is emitted.
                var cachedMethod = new MethodCache(method.Name, method);
                Action<string> cacheWriter = cachedMethod.Add;
                WriteOneMethod(
                    cacheWriter,
                    methodReturnType,
                    method.Name,
                    methodGeneric,
                    methodParameters,
                    method.GetParameters().Select(p => p.Name).ToArray(),
                    methodReturnType != "void",
                    staticOnly,
                    genericWhereSpec);
                cachedMethod.Close();
                methods.Add(cachedMethod);
            }

            PostProcessMethods(methods);
        }

        /// <summary>
        /// Emits rendered methods and resolves name clashes.
        /// A name rendered once is emitted as-is. For a name rendered several times, the one method declared directly on
        /// the target type (rather than inherited) is emitted. If there is not exactly one such method, every method of
        /// that name is skipped and a warning is written for each.
        /// </summary>
        private void PostProcessMethods(IEnumerable<MethodCache> methods)
        {
            var rendered = methods.ToList();
            var nameCounts = rendered.GroupBy(m => m.MethodName).ToDictionary(g => g.Key, g => g.Count());

            // Unique names keep their original order.
            var selected = rendered.Where(m => nameCounts[m.MethodName] == 1).ToList();

            var clashes = rendered
                .Where(m => nameCounts[m.MethodName] > 1)
                .GroupBy(m => m.MethodName)
                .OrderBy(g => g.Key);
            foreach (var clash in clashes)
            {
                var declaredOnTarget = clash.Where(m => m.MethodMetadata.DeclaringType == this.targetType).ToList();
                if (declaredOnTarget.Count == 1)
                {
                    selected.Add(declaredOnTarget[0]);
                } else
                {
                    foreach (var method in clash)
                    {
                        this.warningWriter(string.Format("Duplicate method found '{0}' - unable to choose a duplicate to include in Accessor", method.MethodName));
                    }
                }
            }

            selected.ForEach(m => this.writer(m.Code));
            this.MethodCount += selected.Count;
        }

        /// <summary>
        /// Emits rendered properties and resolves name clashes, using the same rules as <see cref="PostProcessMethods"/>.
        /// </summary>
        private void PostProcessProperties(IEnumerable<PropertyCache> properties)
        {
            var rendered = properties.ToList();
            var nameCounts = rendered.GroupBy(p => p.MethodName).ToDictionary(g => g.Key, g => g.Count());

            // Unique names keep their original order.
            var selected = rendered.Where(p => nameCounts[p.MethodName] == 1).ToList();

            var clashes = rendered
                .Where(p => nameCounts[p.MethodName] > 1)
                .GroupBy(p => p.MethodName)
                .OrderBy(g => g.Key);
            foreach (var clash in clashes)
            {
                var declaredOnTarget = clash.Where(p => p.PropertyMetadata.DeclaringType == this.targetType).ToList();
                if (declaredOnTarget.Count == 1)
                {
                    selected.Add(declaredOnTarget[0]);
                } else
                {
                    foreach (var property in clash)
                    {
                        this.warningWriter(string.Format("Duplicate property found '{0}' - unable to choose a duplicate to include in Accessor", property.MethodName));
                    }
                }
            }

            selected.ForEach(p => this.writer(p.Code));
            this.PropertyCount += selected.Count;
        }

        /// <summary>
        /// Renders wrapper properties for the target's static or instance properties, then emits them via
        /// <see cref="PostProcessProperties"/>. For instance properties, an <c>AccessingThisTarget</c> property exposing the
        /// wrapped object is written first. Indexers are skipped with a warning.
        /// </summary>
        /// <param name="staticOnly">True for static properties, false for instance properties.</param>
        private void WriteProperties(bool staticOnly)
        {
            BindingFlags filter = BindingFlags.NonPublic | BindingFlags.Public | (staticOnly ? BindingFlags.Static : BindingFlags.Instance);
            if (!staticOnly)
            {
                this.writer(string.Format(@"        public {0} AccessingThisTarget {{ get {{ return ({0})this.target; }} }}", PublicizableType(this.targetType)));
            }

            var properties = new List<PropertyCache>();
            foreach (var property in this.targetType.GetProperties(filter)
                .Where(p => !p.Name.Contains("CachedAnonymous") && !MemberExcludeList.Contains(p.Name)))
            {
                // The generated wrapper passes no index arguments, so an indexer wrapper could never work.
                if (property.GetIndexParameters().Length > 0)
                {
                    this.warningWriter(string.Format("Indexer '{0}' is not supported by the private accessor generator and has been skipped.", property.Name));
                    continue;
                }

                Type returnType = PublicizableTypeAsType(property.PropertyType);
                string propertyReturnType = property.PropertyType.IsGenericType
                    ? new GenericTypeStringMaker(returnType).Digest()
                    : returnType.FullName;

                var cacheProperty = new PropertyCache(property.Name, property);
                WriteOneProperty(cacheProperty.Add, propertyReturnType, property.Name, staticOnly, property.CanWrite, property.CanRead);
                cacheProperty.Close();
                properties.Add(cacheProperty);
            }

            PostProcessProperties(properties);
        }

        /// <summary>
        /// Writes wrapper properties for the target's non-public instance fields, skipping auto-property backing fields.
        /// </summary>
        private void WriteFields()
        {
            foreach (var field in this.targetType.GetFields(BindingFlags.Instance | BindingFlags.NonPublic)
                .Where(f => !f.Name.EndsWith("BackingField") && !f.Name.Contains("CachedAnonymous")))
            {
                string typeName = GetFieldTypeName(field);
                FieldCount++;
                this.WriteOneField(typeName, field.Name);
            }
        }

        /// <summary>
        /// Writes one get/set wrapper property over a non-public instance field.
        /// </summary>
        /// <param name="type">The C# type name of the wrapper property.</param>
        /// <param name="property">Name of the wrapped field; also the wrapper's name.</param>
        private void WriteOneField(string type, string property)
        {
            this.writer(
                string.Format(
@"
        public {0} {1} {{
            get {{
                try {{
                    return ({0})TargetType.GetField(""{1}"", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(this.target);
                }} catch (NullReferenceException ex) {{
                    throw new Exception(""The private accessor may be out of date, try regenerating its code. An Object Reference Not Set was thrown."", ex);
                }} catch (TargetInvocationException ex) {{
                    throw ex.InnerException ?? ex;
                }}
            }}

            set {{
                try {{
                    TargetType.GetField(""{1}"", BindingFlags.NonPublic | BindingFlags.Instance).SetValue(this.target, value);
                }} catch (NullReferenceException ex) {{
                    throw new Exception(""The private accessor may be out of date, try regenerating its code. An Object Reference Not Set was thrown."", ex);
                }} catch (TargetInvocationException ex) {{
                    throw ex.InnerException ?? ex;
                }}
            }}
        }}",
                type,
                property));
        }

        /// <summary>
        /// Writes the start of the accessor class: a pragma disabling compiler warning 465, the class declaration (static for
        /// static targets), the <c>TargetType</c> field, and a static constructor that resolves <c>TargetType</c> by assembly
        /// qualified name.
        /// </summary>
        private void WriteHeader()
        {
            this.writer("#pragma warning disable 465");

            if (this.isStaticClass)
            {
                this.writer(
                    string.Format(
@"
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute()]
    [GeneratedCode(""ReesTestToolKit.PrivateAccessorGenerater"", ""1"")]
    public static class {0}_Accessor {{
        public static Type TargetType;",
                        this.targetType.Name));
            } else
            {
                this.writer(
                    string.Format(
@"
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute()]
    [GeneratedCode(""ReesTestToolKit.PrivateAccessorGenerater"", ""1"")]
    public class {0}_Accessor {{
        private readonly object target;
        public static Type TargetType;
",
                          this.targetType.Name));
            }

            this.writer(
                string.Format(
                    @"        static {0}_Accessor() {{
            TargetType = Type.GetType(""{1}"");
        }}
",
                    this.targetType.Name,
                    this.targetType.AssemblyQualifiedName));
        }

        /// <summary>
        /// Closes the accessor class and restores compiler warning 465.
        /// </summary>
        private void WriteFooter()
        {
            this.writer("    } // End Class " + this.targetType.Name);
            this.writer("// ============================================================================================================");
            this.writer("#pragma warning restore 465");
        }
    }
}

No comments:

Post a Comment