Skip to content

Commit

Permalink
Update pytorch.js (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 15, 2024
1 parent 8f2858e commit 8f8de68
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 8 deletions.
77 changes: 71 additions & 6 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6154,12 +6154,21 @@ python.Execution = class {
kind() {
return this._kind;
}
__str__() {
throw new python.Error('Not implemented.');
}
toString() {
return this.__str__();
}
});
this.registerType('torch.ClassType', class extends torch.Type {
constructor(qualified_name, cu, is_module) {
super();
this._qualified_name = qualified_name;
this._is_module = is_module;
this._attributes = new Map();
this._methods = new Map();
this._staticmethods = new Map();
}
qualified_name() {
return this._qualified_name;
Expand All @@ -6170,11 +6179,23 @@ python.Execution = class {
is_module() {
return this._is_module;
}
addMethod(/* name, fn */) {
addMethod(func) {
this._methods.set(func.name, func);
}
addAttribute(/* name */) {
findMethod(name) {
return this._methods.get(name);
}
hasAttribute(/* name */) {
addStaticMethod(func) {
this._staticmethods.set(func.name, func);
}
findStaticMethod(name) {
return this._staticmethods.get(name);
}
addAttribute(name, type) {
this._attributes.set(name, type);
}
findAttribute(name) {
return this._attributes.get(name);
}
hasConstant(/* name */) {
}
Expand All @@ -6189,6 +6210,9 @@ python.Execution = class {
getElementType() {
return this._elem;
}
__str__() {
return `Optional[${this.getElementType().toString()}]`;
}
});
this.registerType('torch.ListType', class extends torch.Type {
constructor(elem, size) {
Expand All @@ -6201,6 +6225,9 @@ python.Execution = class {
getElementType() {
return this._elem;
}
__str__() {
return `List[${this.getElementType().toString()}]`;
}
});
this.registerType('torch.FutureType', class extends torch.Type {
constructor(elem, size) {
Expand All @@ -6213,16 +6240,32 @@ python.Execution = class {
}
});
this.registerType('torch.TupleType', class extends torch.Type {
constructor() {
constructor(elements) {
super('TupleType');
this._elements = elements;
}
elements() {
return this._elements;
}
});
this.registerType('torch.TensorType', class extends torch.Type {
constructor() {
super('TensorType');
}
__str__() {
return 'Tensor';
}
});
this.registerType('torch.AnyType', class extends torch.Type {
constructor() {
super('AnyType');
}
});
this.registerType('torch.NoneType', class extends torch.Type {
constructor() {
super('NoneType');
}
});
this.registerType('torch.AnyType', class extends torch.Type {});
this.registerType('torch.NumberType', class extends torch.Type {
constructor() {
super('NumberType');
Expand All @@ -6232,11 +6275,17 @@ python.Execution = class {
constructor() {
super('BoolType');
}
__str__() {
return 'bool';
}
});
this.registerType('torch.IntType', class extends torch.Type {
constructor() {
super('IntType');
}
__str__() {
return 'int';
}
});
this.registerType('torch.SymIntType', class extends torch.Type {
constructor() {
Expand All @@ -6247,16 +6296,25 @@ python.Execution = class {
constructor() {
super('FloatType');
}
__str__() {
return 'float';
}
});
this.registerType('torch.StringType', class extends torch.Type {
constructor() {
super('StringType');
}
__str__() {
return 'str';
}
});
this.registerType('torch.ComplexType', class extends torch.Type {
constructor() {
super('ComplexType');
}
__str__() {
return 'complex';
}
});
this.registerType('torch.DictType', class extends torch.Type {
constructor(key, value) {
Expand All @@ -6271,7 +6329,14 @@ python.Execution = class {
return this._value;
}
});
this.registerType('torch.DeviceObjType', class extends torch.Type {});
this.registerType('torch.DeviceObjType', class extends torch.Type {
constructor() {
super('DeviceObjType');
}
__str__() {
return 'Device';
}
});
this.registerType('torch._C._GeneratorType', class extends torch.Type {});
this.registerType('torch.Argument', class {
constructor(name, type, real_type, N, default_value, kwarg_only, alias_info) {
Expand Down
9 changes: 7 additions & 2 deletions source/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,12 @@ def _argument(self, argument, value):
argument_type = '[' + size + ']' + argument_type
value = value.element_type
elif isinstance(value, Schema.DictType):
value = str(value)
name = value.getKeyType().name
key_type = self._primitives[name] if name in self._primitives else name
name = value.getValueType().name
value_type = self._primitives[name] if name in self._primitives else name
value = f'Dict({key_type}, {value_type})'
argument_type = value
else:
name = value.name
name = self._primitives[name] if name in self._primitives else name
Expand Down Expand Up @@ -498,7 +503,7 @@ def __init__(self, key_type, value_type):
self._key_type = key_type
self._value_type = value_type
def __str__(self):
return 'Dict[' + str(self._key_type) + ', ' + str(self._value_type) + ']'
return 'Dict(' + str(self._key_type) + ', ' + str(self._value_type) + ')'
def getKeyType(self): # pylint: disable=invalid-name,missing-function-docstring
return self._key_type
def getValueType(self): # pylint: disable=invalid-name,,missing-function-docstring
Expand Down

0 comments on commit 8f8de68

Please sign in to comment.