Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cln-rpc/src/model.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions contrib/msggen/msggen/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def recurse(f: model.Field):
if isinstance(f, model.ArrayField):
self.visit(f.itemtype)
recurse(f.itemtype)
elif isinstance(f, model.UnionField):
for v in f.variants:
self.visit(v)
recurse(v)
elif isinstance(f, model.CompositeField):
for c in f.fields:
self.visit(c)
Expand Down
94 changes: 92 additions & 2 deletions contrib/msggen/msggen/gen/grpc/convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# A grpc model
from msggen.model import ArrayField, CompositeField, EnumField, PrimitiveField, Service
from msggen.gen.grpc.util import notification_typename_overrides
from msggen.model import ArrayField, CompositeField, EnumField, PrimitiveField, UnionField, Service
from msggen.gen.grpc.util import notification_typename_overrides, camel_to_snake, union_variant_suffix
from msggen.gen.rpc.rust import union_variant_name
from msggen.gen import IGenerator
from typing import TextIO
from textwrap import indent, dedent
Expand All @@ -17,6 +18,86 @@ def generate_array(self, prefix, field: ArrayField, override):
if isinstance(field.itemtype, CompositeField):
self.generate_composite(prefix, field.itemtype, override)

def union_variant_conversion(self, f, val="v"):
"""Generate the conversion expression for a single union variant value."""
if isinstance(f, PrimitiveField):
mapping = {
"short_channel_id": f"{val}.to_string()",
"short_channel_id_dir": f"{val}.to_string()",
"pubkey": f"{val}.serialize().to_vec()",
"hex": f"hex::decode({val}).unwrap()",
"txid": f"hex::decode({val}).unwrap()",
"hash": f"<Sha256 as AsRef<[u8]>>::as_ref(&{val}).to_vec()",
"secret": f"{val}.to_vec()",
"msat": f"{val}.into()",
"msat_or_all": f"{val}.into()",
"msat_or_any": f"{val}.into()",
"sat": f"{val}.into()",
"sat_or_all": f"{val}.into()",
"feerate": f"{val}.into()",
"outpoint": f"{val}.into()",
}.get(f.typename, val)
return mapping
elif isinstance(f, ArrayField):
inner_mapping = {
"short_channel_id": "i.to_string()",
"short_channel_id_dir": "i.to_string()",
"pubkey": "i.serialize().to_vec()",
"hex": "hex::decode(i).unwrap()",
"txid": "hex::decode(i).unwrap()",
}.get(f.itemtype.typename, "i.into()")
return f"{val}.into_iter().map(|i| {inner_mapping}).collect()"
elif isinstance(f, EnumField):
return f"{val} as i32"
elif isinstance(f, CompositeField):
return f"{val}.into()"
return val

def generate_union(self, prefix, field: UnionField, parent_typename, override=None):
"""Generate From impl for a union type (cln-rpc enum -> pb oneof)."""
if override is None:
override = lambda x: x

typename = str(field.typename)
pbname = override(self.to_camel_case(str(override(parent_typename))))
pb_mod = camel_to_snake(pbname)
oneof_name = field.normalized()
# The prost enum name is CamelCase of the oneof field name
pb_oneof_enum = self.to_camel_case(oneof_name[0].upper() + oneof_name[1:])

self.write(
f"""\
impl From<{prefix}::{typename}> for pb::{pb_mod}::{pb_oneof_enum} {{
fn from(c: {prefix}::{typename}) -> Self {{
match c {{
"""
)

for v in field.variants:
vname = union_variant_name(v)
suffix = union_variant_suffix(v)
pb_variant = self.to_camel_case(f"{oneof_name}_{suffix}")
pb_variant = pb_variant[0].upper() + pb_variant[1:]
if isinstance(v, ArrayField):
wrapper_name = override(f"{parent_typename}{suffix}Wrapper")
wrapper_pb = self.to_camel_case(str(wrapper_name))
self.write(
f" {prefix}::{typename}::{vname}(v) => pb::{pb_mod}::{pb_oneof_enum}::{pb_variant}(pb::{wrapper_pb} {{ items: v.into_iter().map(|i| {self.union_variant_conversion(v.itemtype, 'i')}).collect() }}),\n"
)
else:
self.write(
f" {prefix}::{typename}::{vname}(v) => pb::{pb_mod}::{pb_oneof_enum}::{pb_variant}({self.union_variant_conversion(v)}),\n"
)

self.write(
f"""\
}}
}}
}}

"""
)

def generate_composite(self, prefix, field: CompositeField, override=None):
"""Generates the conversions from JSON-RPC to GRPC."""
if field.omit():
Expand All @@ -34,6 +115,8 @@ def generate_composite(self, prefix, field: CompositeField, override=None):
self.generate_array(prefix, f, override)
elif isinstance(f, CompositeField):
self.generate_composite(prefix, f, override)
elif isinstance(f, UnionField):
self.generate_union(prefix, f, str(field.typename), override)

pbname = override(self.to_camel_case(str(override(field.typename))))

Expand Down Expand Up @@ -148,6 +231,13 @@ def generate_composite(self, prefix, field: CompositeField, override=None):
else:
rhs = f"c.{name}.map(|v| v.into())"
self.write(f"{name}: {rhs},\n", numindent=3)

elif isinstance(f, UnionField):
if not f.optional:
self.write(f"{name}: Some(c.{name}.into()),\n", numindent=3)
else:
self.write(f"{name}: c.{name}.map(|v| v.into()),\n", numindent=3)

self.write(
f"""\
}}
Expand Down
57 changes: 56 additions & 1 deletion contrib/msggen/msggen/gen/grpc/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
typemap,
method_name_overrides,
notification_typename_overrides,
union_variant_suffix,
)
from msggen.model import (
ArrayField,
Field,
CompositeField,
EnumField,
PrimitiveField,
UnionField,
Service,
MethodName,
TypeName,
Expand Down Expand Up @@ -177,6 +179,32 @@ def generate_enum(self, e: EnumField, indent=0, typename_override=None):

self.write(f"""{prefix}}}\n""", False)

def union_variant_proto_type(self, f, parent_typename):
"""Return the protobuf type name for a union variant field."""
if isinstance(f, PrimitiveField):
return typemap.get(f.typename, f.typename)
elif isinstance(f, EnumField):
return str(f.typename)
elif isinstance(f, CompositeField):
return str(f.typename)
elif isinstance(f, ArrayField):
# oneof cannot contain repeated, so we need a wrapper message
return f"{parent_typename}{union_variant_suffix(f)}Wrapper"
return "bytes"

def generate_union_wrapper_messages(self, u: UnionField, parent_typename, typename_override=None):
"""Generate wrapper messages for array variants inside a union (oneof can't contain repeated)."""
if typename_override is None:
typename_override = lambda x: x

for v in u.variants:
if isinstance(v, ArrayField):
wrapper_name = f"{parent_typename}{union_variant_suffix(v)}Wrapper"
item_type = typemap.get(v.itemtype.typename, v.itemtype.typename)
self.write(f"\nmessage {typename_override(wrapper_name)} {{\n", False)
self.write(f"\trepeated {item_type} items = 1;\n", False)
self.write(f"}}\n", False)

def generate_message(self, message: CompositeField, typename_override=None):
if message.omit():
return
Expand All @@ -186,6 +214,11 @@ def generate_message(self, message: CompositeField, typename_override=None):
if typename_override is None:
typename_override = lambda x: x

# Generate wrapper messages for any union fields first
for f in message.fields:
if isinstance(f, UnionField):
self.generate_union_wrapper_messages(f, str(message.typename), typename_override)

self.write(
f"""
message {typename_override(message.typename)} {{
Expand All @@ -203,7 +236,26 @@ def generate_message(self, message: CompositeField, typename_override=None):

opt = "optional " if f.optional and not (isinstance(f, PrimitiveField) and f.typename == "string_map") else ""

if isinstance(f, ArrayField):
if isinstance(f, UnionField):
self.write(f"\toneof {f.normalized()} {{\n", False)
first_variant = True
for v in f.variants:
suffix = union_variant_suffix(v)
vname = f"{f.normalized()}_{suffix}"
vtype = self.union_variant_proto_type(v, str(message.typename))
vtype = typename_override(vtype)
if first_variant:
# Reuse the original field number for backward compat
vnum = i
first_variant = False
else:
# Allocate new numbers for additional variants
parent = ".".join(f.path.split(".")[:-1])
vfield = PrimitiveField(suffix, f"{parent}.{f.normalized()}_{suffix}", "", added=f.added, deprecated=f.deprecated)
vnum = self.field2number(message.typename, vfield)
self.write(f"\t\t{vtype} {vname} = {vnum};\n", False)
self.write(f"\t}}\n", False)
elif isinstance(f, ArrayField):
typename = f.override(
typemap.get(f.itemtype.typename, f.itemtype.typename)
)
Expand Down Expand Up @@ -264,5 +316,8 @@ def gather_subfields(field: Field) -> List[Field]:
elif isinstance(field, ArrayField):
fields = []
fields.extend(gather_subfields(field.itemtype))
elif isinstance(field, UnionField):
for v in field.variants:
fields.extend(gather_subfields(v))

return fields
85 changes: 84 additions & 1 deletion contrib/msggen/msggen/gen/grpc/unconvert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# A grpc model
from msggen.model import ArrayField, CompositeField, EnumField, PrimitiveField, Service
from msggen.model import ArrayField, CompositeField, EnumField, PrimitiveField, UnionField, Service
from msggen.gen.grpc.convert import GrpcConverterGenerator
from msggen.gen.grpc.util import camel_to_snake, union_variant_suffix
from msggen.gen.rpc.rust import union_variant_name


class GrpcUnconverterGenerator(GrpcConverterGenerator):
Expand All @@ -12,6 +14,79 @@ def generate(self, service: Service):
# TODO Temporarily disabled since the use of overrides is lossy
# self.generate_responses(service)

def unconvert_variant_value(self, f, val="v"):
"""Generate the reverse conversion expression for a union variant value (pb -> cln-rpc)."""
if isinstance(f, PrimitiveField):
return {
"short_channel_id": f"cln_rpc::primitives::ShortChannelId::from_str(&{val}).unwrap()",
"short_channel_id_dir": f"cln_rpc::primitives::ShortChannelIdDir::from_str(&{val}).unwrap()",
"pubkey": f"PublicKey::from_slice(&{val}).unwrap()",
"hex": f"hex::encode({val})",
"txid": f"hex::encode({val})",
"hash": f"Sha256::from_slice(&{val}).unwrap()",
"secret": f"{val}.try_into().unwrap()",
"msat": f"{val}.into()",
"msat_or_all": f"{val}.into()",
"msat_or_any": f"{val}.into()",
"sat": f"{val}.into()",
"sat_or_all": f"{val}.into()",
"feerate": f"{val}.into()",
"outpoint": f"{val}.into()",
}.get(f.typename, val)
elif isinstance(f, ArrayField):
inner_mapping = {
"short_channel_id": "cln_rpc::primitives::ShortChannelId::from_str(&s).unwrap()",
"short_channel_id_dir": "cln_rpc::primitives::ShortChannelIdDir::from_str(&s).unwrap()",
"pubkey": "PublicKey::from_slice(&s).unwrap()",
"hex": "hex::encode(s)",
"txid": "hex::encode(s)",
}.get(f.itemtype.typename, "s.into()")
return f"{val}.items.into_iter().map(|s| {inner_mapping}).collect()"
elif isinstance(f, EnumField):
return f"{val}.try_into().unwrap()"
elif isinstance(f, CompositeField):
return f"{val}.into()"
return val

def generate_union_unconvert(self, prefix, field: UnionField, parent_typename, override=None):
"""Generate From impl for pb oneof -> cln-rpc union type."""
if override is None:
override = lambda x: x

typename = str(field.typename)
pbname = self.to_camel_case(str(parent_typename))
pb_mod = camel_to_snake(pbname)
oneof_name = field.normalized()
pb_oneof_enum = self.to_camel_case(oneof_name[0].upper() + oneof_name[1:])

self.write(
f"""\
impl From<pb::{pb_mod}::{pb_oneof_enum}> for {prefix}::{typename} {{
fn from(c: pb::{pb_mod}::{pb_oneof_enum}) -> Self {{
match c {{
"""
)

for v in field.variants:
vname = union_variant_name(v)
suffix = union_variant_suffix(v)
pb_variant = self.to_camel_case(f"{oneof_name}_{suffix}")
pb_variant = pb_variant[0].upper() + pb_variant[1:]
conv = self.unconvert_variant_value(v)

self.write(
f" pb::{pb_mod}::{pb_oneof_enum}::{pb_variant}(v) => {prefix}::{typename}::{vname}({conv}),\n"
)

self.write(
f"""\
}}
}}
}}

"""
)

def generate_composite(self, prefix, field: CompositeField, override=None) -> None:
# First pass: generate any sub-fields before we generate the
# top-level field itself.
Expand All @@ -26,6 +101,8 @@ def generate_composite(self, prefix, field: CompositeField, override=None) -> No
self.generate_array(prefix, f, override)
elif isinstance(f, CompositeField):
self.generate_composite(prefix, f, override)
elif isinstance(f, UnionField):
self.generate_union_unconvert(prefix, f, str(field.typename), override)

has_deprecated = any([f.deprecated for f in field.fields])
deprecated = ",deprecated" if has_deprecated else ""
Expand Down Expand Up @@ -146,6 +223,12 @@ def generate_composite(self, prefix, field: CompositeField, override=None) -> No
rhs = f"c.{name}.map(|v| v.into())"
self.write(f"{name}: {rhs},\n", numindent=3)

elif isinstance(f, UnionField):
if not f.optional:
self.write(f"{name}: c.{name}.unwrap().into(),\n", numindent=3)
else:
self.write(f"{name}: c.{name}.map(|v| v.into()),\n", numindent=3)

self.write(
f"""\
}}
Expand Down
33 changes: 33 additions & 0 deletions contrib/msggen/msggen/gen/grpc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,36 @@ def camel_to_snake(camel_case: str):
snake = re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case).lower()
snake = snake.replace("-", "")
return snake


def union_variant_suffix(f):
"""Generate a short suffix for a union variant field name in proto/grpc contexts."""
from msggen.model import ArrayField, CompositeField, EnumField, PrimitiveField
if isinstance(f, PrimitiveField):
return {
"boolean": "bool",
"string": "string",
"integer": "int",
"u32": "u32",
"u64": "u64",
"short_channel_id": "scid",
"short_channel_id_dir": "sciddir",
"msat": "msat",
"msat_or_all": "msat_or_all",
"msat_or_any": "msat_or_any",
"sat": "sat",
"sat_or_all": "sat_or_all",
"pubkey": "pubkey",
"hex": "hex",
"number": "number",
"feerate": "feerate",
"currency": "currency",
}.get(f.typename, f.typename)
elif isinstance(f, ArrayField):
inner = union_variant_suffix(f.itemtype)
return f"arr_{inner}"
elif isinstance(f, EnumField):
return str(f.typename).lower()
elif isinstance(f, CompositeField):
return str(f.typename).lower()
return "unknown"
Loading