Skip to content

Commit

Permalink
Add an option to accept nil on attribute writer
Browse files Browse the repository at this point in the history
  • Loading branch information
soutaro committed Apr 25, 2022
1 parent 8751baa commit 71cf127
Show file tree
Hide file tree
Showing 4 changed files with 636 additions and 133 deletions.
3 changes: 2 additions & 1 deletion exe/protoc-gen-rbs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ translator = case backend
input,
upcase_enum: upcase_enum,
nested_namespace: !no_nested_namespace,
extension: extension
extension: extension,
accept_nil_writer: false
)
when "google-protobuf"
raise NotImplementedError
Expand Down
217 changes: 112 additions & 105 deletions lib/rbs_protobuf/translator/protobuf_gem.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ class ProtobufGem < Base

attr_reader :stderr

def initialize(input, upcase_enum:, nested_namespace:, extension:, stderr: STDERR)
attr_reader :accept_nil_writer

def initialize(input, upcase_enum:, nested_namespace:, extension:, accept_nil_writer:, stderr: STDERR)
super(input)
@upcase_enum = upcase_enum
@nested_namespace = nested_namespace
@extension = extension
@accept_nil_writer = accept_nil_writer
@stderr = stderr
end

Expand Down Expand Up @@ -427,96 +430,109 @@ def message_field_hash_type(type, key)
end

def field_type(field, maps)
case
when field.type == FieldDescriptorProto::Type::TYPE_MESSAGE
if maps.key?(field.type_name)
key_field, value_field = maps[field.type_name]

key_type_r, _ = field_type(key_field, maps)
value_type_r, value_write_types = field_type(value_field, maps)

value_type_r = factory.unwrap_optional(value_type_r)
value_write_types = value_write_types.map {|type| factory.unwrap_optional(type) }
# @type var triple: [RBS::Types::t, Array[RBS::Types::t], RBS::Types::t]
triple =
case
when field.type == FieldDescriptorProto::Type::TYPE_MESSAGE
if maps.key?(field.type_name)
key_field, value_field = maps[field.type_name]

key_type_r, _ = field_type(key_field, maps)
value_type_r, value_write_types = field_type(value_field, maps)

value_type_r = factory.unwrap_optional(value_type_r)
value_write_types = value_write_types.map {|type| factory.unwrap_optional(type) }

case value_field.type
when FieldDescriptorProto::Type::TYPE_MESSAGE, FieldDescriptorProto::Type::TYPE_ENUM
value_type_r.is_a?(RBS::Types::ClassInstance) or raise
[
message_field_hash_type(value_type_r, key_type_r),
[message_hash_type(value_type_r, key_type_r)],
message_hash_type(value_type_r, key_type_r)
]
else
hash_type = FIELD_HASH[
key_type_r,
value_type_r,
factory.union_type(value_type_r, *value_write_types)
]

case value_field.type
when FieldDescriptorProto::Type::TYPE_MESSAGE, FieldDescriptorProto::Type::TYPE_ENUM
value_type_r.is_a?(RBS::Types::ClassInstance) or raise
[
message_field_hash_type(value_type_r, key_type_r),
[message_hash_type(value_type_r, key_type_r)],
message_hash_type(value_type_r, key_type_r)
]
[
FIELD_HASH_a[key_type_r, value_type_r],
[RBS::BuiltinNames::Hash.instance_type(key_type_r, value_type_r)],
RBS::BuiltinNames::Hash.instance_type(key_type_r, value_type_r)
]
end
else
hash_type = FIELD_HASH[
key_type_r,
value_type_r,
factory.union_type(value_type_r, *value_write_types)
]
type = message_type(field.type_name)

[
FIELD_HASH_a[key_type_r, value_type_r],
[RBS::BuiltinNames::Hash.instance_type(key_type_r, value_type_r)],
RBS::BuiltinNames::Hash.instance_type(key_type_r, value_type_r)
]
case field.label
when FieldDescriptorProto::Label::LABEL_OPTIONAL
[
factory.optional_type(type),
[
factory.optional_type(message_to_proto_type(type))
],
factory.optional_type(message_init_type(type))
]
when FieldDescriptorProto::Label::LABEL_REPEATED
[
message_field_array_type(type),
[
message_array_type(type)
],
message_array_type(type)
]
else
[
type,
[message_to_proto_type(type)],
message_init_type(type)
]
end
end
else
when field.type == FieldDescriptorProto::Type::TYPE_ENUM
type = message_type(field.type_name)
enum_namespace = type.name.to_namespace
values = factory.alias_type(RBS::TypeName.new(name: :values, namespace: enum_namespace))

case field.label
when FieldDescriptorProto::Label::LABEL_OPTIONAL
[
factory.optional_type(type),
[
factory.optional_type(message_to_proto_type(type))
],
factory.optional_type(message_init_type(type))
]
when FieldDescriptorProto::Label::LABEL_REPEATED
if field.label == FieldDescriptorProto::Label::LABEL_REPEATED
[
message_field_array_type(type),
[
message_array_type(type)
],
[message_array_type(type)],
message_array_type(type)
]
else
[
type,
[message_to_proto_type(type)],
[values],
message_init_type(type)
]
end
end
when field.type == FieldDescriptorProto::Type::TYPE_ENUM
type = message_type(field.type_name)
enum_namespace = type.name.to_namespace
values = factory.alias_type(RBS::TypeName.new(name: :values, namespace: enum_namespace))

if field.label == FieldDescriptorProto::Label::LABEL_REPEATED
[
message_field_array_type(type),
[message_array_type(type)],
message_array_type(type)
]
else
[
type,
[values],
message_init_type(type)
]
end
else
type = base_type(field.type)
type = base_type(field.type)

if field.label == FieldDescriptorProto::Label::LABEL_REPEATED
[
FIELD_ARRAY_a[type],
[RBS::BuiltinNames::Array.instance_type(type)],
RBS::BuiltinNames::Array.instance_type(type)
]
else
[type, [], type]
if field.label == FieldDescriptorProto::Label::LABEL_REPEATED
[
FIELD_ARRAY_a[type],
[RBS::BuiltinNames::Array.instance_type(type)],
RBS::BuiltinNames::Array.instance_type(type)
]
else
[type, [], type]
end
end

if accept_nil_writer
read_type, write_types, init_type = triple
[
read_type,
([factory.optional_type(read_type)] + write_types.map {|t| factory.optional_type(t) }).uniq,
factory.optional_type(init_type)
]
else
triple
end
end

Expand Down Expand Up @@ -548,40 +564,31 @@ def add_field(members, name:, read_type:, write_types:, comment:)
kind: :instance
)

write_types.each do |write_type|
if (type_param, type = interface_type?(write_type))
members << RBS::AST::Members::MethodDefinition.new(
name: :"#{name}=",
types: [
factory.method_type(
type: factory.function(type).update(
required_positionals:[factory.param(type)]
)
).update(type_params: [type_param])
],
annotations: [],
comment: comment,
location: nil,
overload: true,
kind: :instance
)
else
members << RBS::AST::Members::MethodDefinition.new(
name: :"#{name}=",
types: [
factory.method_type(
type: factory.function(write_type).update(
required_positionals:[factory.param(write_type)]
unless write_types.empty?
members << RBS::AST::Members::MethodDefinition.new(
name: :"#{name}=",
types:
write_types.map do |write_type|
if (type_param, type = interface_type?(write_type))
factory.method_type(
type: factory.function(type).update(
required_positionals:[factory.param(type)]
)
).update(type_params: [type_param])
else
factory.method_type(
type: factory.function(write_type).update(
required_positionals:[factory.param(write_type)]
)
)
)
],
annotations: [],
comment: comment,
location: nil,
overload: true,
kind: :instance
)
end
end
end,
annotations: [],
comment: comment,
location: nil,
overload: true,
kind: :instance
)
end

members << RBS::AST::Members::MethodDefinition.new(
Expand Down
5 changes: 4 additions & 1 deletion sig/rbs_protobuf/translator/protobuf_gem.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ module RBSProtobuf

attr_reader stderr: IO

attr_reader accept_nil_writer: bool

def initialize: (
untyped input,
upcase_enum: bool,
nested_namespace: bool,
extension: bool | :print | nil,
?stderr: IO
accept_nil_writer: bool,
?stderr: IO,
) -> void

@upcase_enum: bool
Expand Down
Loading

0 comments on commit 71cf127

Please sign in to comment.